
Improve fake mode support by adding fake_context to ExportOutput #105247

Closed

Conversation

@thiagocrepaldi (Collaborator) commented Jul 14, 2023

Stack from ghstack (oldest at bottom):

Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the initial model state dict (including non-persistent buffers) would not be reused by `ExportOutput.save` during ONNX proto creation.

That is not necessarily a bug, because `ExportOutput.save` has a `model_state_dict` argument in which the user can specify any state they want. However, it can be a hassle: if the user doesn't provide a full state, including non-persistent buffers, the resulting ONNX graph requires the missing buffers to be specified as inputs during execution.
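For context, a minimal example of the kind of buffer involved (an illustrative model, not one from this PR):

```python
import torch

class ModelWithNonPersistentBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False keeps "scale" out of state_dict(), so a state
        # provided via model_state_dict alone would miss it.
        self.register_buffer("scale", torch.ones(3), persistent=False)

    def forward(self, x):
        return x * self.scale

m = ModelWithNonPersistentBuffer()
assert "scale" not in m.state_dict()       # excluded from state_dict()
assert "scale" in dict(m.named_buffers())  # but still a real buffer
```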

With this PR, `enable_fake_mode` is improved to capture the initial model state, including any non-persistent buffers. This reference (not the actual data) is persisted within `ExportOutput` and used by `save` to load the additional state_dict that was captured by `enable_fake_mode`. The result is an ONNX graph with all model state, without the user having to specify the non-persistent buffers.
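A sketch of the intended flow, based on the description above (API shapes from the PyTorch 2.1-era dynamo exporter; `MyLargeModel` and the file names are placeholders):

```python
import torch

# Sketch only; exact signatures may differ across versions.
with torch.onnx.enable_fake_mode() as fake_context:
    model = MyLargeModel()                        # parameters are fake tensors
    model.load_state_dict(torch.load("ckpt.pt"))  # state captured by fake_context
    args = (torch.randn(1, 16),)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(model, *args, export_options=export_options)

# save() can now reuse the state captured above, including non-persistent
# buffers, instead of requiring a user-supplied model_state_dict.
export_output.save("model.onnx")
```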

This helps address #105233 for models that call `fake_model.load_state_dict()` under the hood, since buffers not returned by `model.state_dict()` may still be captured.

P.S.: #105464 tracks pending tasks/limitations from this PR.

@pytorch-bot commented Jul 14, 2023

🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/105247. Note: links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 5549bb6. 💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the `release notes: onnx` label (torch.onnx related changes that should show up in the release notes) on Jul 14, 2023
@thiagocrepaldi added the `module: onnx` label (Related to torch.onnx) on Jul 14, 2023
@BowenBao (Collaborator) left a comment:
Hmm, I'm wondering if the larger problem is about how to handle the tensor constants that are neither parameters nor persistent buffers under the model. A few questions:

  • Are non-persistent buffers the only case left for tensor constants that are not parameters or persistent buffers?
  • What about only fakefying the things inside the checkpoint? That implies we set `allow_non_fake_inputs` and the issue is gone?

Review thread on test/onnx/test_fx_to_onnx_with_onnxruntime.py:
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because sometimes, tensor names are stored with a prefix in the
# dictionary loaded by torch.load.
for onnx_input_name in onnx_input_names:
Collaborator: why is this change necessary?

@thiagocrepaldi (Collaborator, Author) replied Jul 21, 2023:
I have added one more line to the previous comment block to explain the reasoning. Let me know if this is not sufficient.

# This block tries to match the onnx initializer name with torch parameter/buffer

This is not actually related to the initializer renaming controlled by `rename_initializer`; it tries to match the torch name (from a state_dict) to the ONNX initializer name. For tiny GPT2, for example, the torch names are prefixed with `transformer.` while the ONNX names aren't, with or without the initializer rename feature (which only replaces `.` by `_`).

With the name matching, the ONNX graph will have all initializers saved within the proto.
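For illustration, a minimal sketch of the kind of prefix-tolerant matching described above (the helper name is hypothetical, not the exporter's actual code):

```python
# Hypothetical sketch of prefix-tolerant name matching; not the actual
# exporter implementation.
def match_state_dict_key(onnx_input_name: str, state_dict: dict):
    """Match an ONNX initializer/input name (e.g. "h.0.attn.weight") to a
    torch state_dict key that may carry an extra module prefix
    (e.g. "transformer.h.0.attn.weight")."""
    for torch_name, tensor in state_dict.items():
        if torch_name == onnx_input_name or torch_name.endswith("." + onnx_input_name):
            return tensor
    return None  # no match: the value stays a required graph input
```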

Collaborator:

I thought the name prefixing was only for the symbolic tracer graph. This is unexpected for the dynamo graph, since the restore-naming pass in readability.py retrieves names from both `named_parameters` and `named_buffers`. If there is a `transformer.` prefix in the torch name, it should be there in the ONNX initializer name as well. I wonder what went wrong in the middle.

Collaborator (Author):

IIRC, the `transformer.` prefix came from the checkpoint file, not the actual torch model names.

Collaborator:

Step 1 is required because sometimes, tensor names are stored with a prefix in the dictionary loaded by torch.load.

Hmm.. OK, so names in the ckpt are sometimes different from the names in named_parameters and named_buffers. I think there could be a more systematic fix, then. Please also track this as one of the issues.

@BowenBao (Collaborator) commented Jul 21, 2023:

Can you help me point to the code where the non-persistent buffer data, which is not returned by state_dict(), is recorded?

@thiagocrepaldi (Collaborator, Author) commented Jul 21, 2023:

Can you help me point to the code where the non-persistent buffer data, which is not returned by state_dict(), is recorded?

There is no such code :) Before and after this PR, it is the user's responsibility to provide a full state to the exporter if one is not available. In this corner case (e.g., a model with non-persistent buffers and no checkpoint with state), they have to provide it with something like

state_dict = real_model.state_dict()
state_dict.update(dict(real_model.named_buffers()))  # Save non-persistent buffers along with the rest of the model state_dict

If the user decides that a buffer is not to be persistent, there is nothing we can do about it unless they provide a checkpoint with the full state.

This PR improves our export process when a checkpoint (not a state_dict) is present during initialization. From now on, we can capture the full state of a model when one is present in the checkpoint, without user intervention. This is done by moving the ONNXPytorchPatcher from `ExportOutput.save` to `torch.onnx.enable_fake_mode`. This way, when the user calls, say, `transformers.AutoModel.from_pretrained(model_name)` or `model.load_state_dict()`, we will be able to capture such (full) state and create initializers with it later on, during `ExportOutput.save`. That is exactly what happens with Hugging Face models (they download checkpoints during init), so I managed to remove the `state_dict.update(dict(real_model.named_buffers()))` from the unit tests, because we captured the state from the Hugging Face cache, which contains the buffers required to run the model.
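For illustration, a minimal sketch of the interception idea (the class and names here are hypothetical; the actual patcher in the exporter is more involved):

```python
import torch

# Hypothetical sketch: wrap torch.load so that any checkpoint loaded while
# fake mode is enabled is remembered for later use by ExportOutput.save.
class RecordingPatcher:
    def __init__(self):
        self.paths = []  # checkpoint files seen during enable_fake_mode

    def __enter__(self):
        self._orig_load = torch.load

        def _load(f, *args, **kwargs):
            self.paths.append(f)  # record where the state came from
            return self._orig_load(f, *args, **kwargs)

        torch.load = _load
        return self

    def __exit__(self, *exc):
        torch.load = self._orig_load  # restore the original entry point
```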

We can create a separate task to try to look for missing state and raise a friendly error to the user as early as possible in the export process, but it shouldn't be "normal" for users to instantiate models with "incomplete" state.
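Such a check could be as simple as this sketch (public nn.Module APIs only; illustrative, not exporter code):

```python
import torch

def missing_nonpersistent_buffers(model: torch.nn.Module) -> set:
    # Buffers that exist on the module but are absent from state_dict()
    # are, by definition, non-persistent; without extra state they would
    # surface later as required ONNX graph inputs.
    return set(dict(model.named_buffers())) - set(model.state_dict())
```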

@thiagocrepaldi (Collaborator, Author) commented Jul 21, 2023:

Hmm I'm wondering if the larger problem is about how to handle the tensor constants that are neither parameters nor persistent buffers under the model. A few questions

  • Are non-persistent buffers the only case left for tensor constants that are not parameters or persistent buffers?

We will learn about corner cases and fix them as we use this feature more often. We didn't know about the non-persistent buffers until a few days ago, but now we have an official way to capture non-persistent buffers from the checkpoint and a workaround for users to manually specify them (if ever needed).

  • What about only fakefying the things inside the checkpoint? That implies we set allow_non_fake_inputs and the issue is gone?

That issue is not tackled by this PR. Here, all we care about is capturing a state when one is available. Before this PR, we discarded state that was given for free and then required the user to specify it.

#105477 gives a first shot at that and already allows `allow_non_fake_inputs=False` for some scenarios.

@BowenBao (Collaborator):

I see, so essentially `transformers.AutoModel.from_pretrained(model_name)` and `model.load_state_dict()` create (or load?) non-persistent buffers, and the ONNXPytorchPatcher hijacks and keeps track of them.

@BowenBao (Collaborator) left a comment:

We should probably follow up on the naming prefix issue; the rest looks good.

@thiagocrepaldi (Collaborator, Author):
We should probably follow up on the naming prefix issue, others look good.

Added #105751 to my task list to address this.

@thiagocrepaldi (Collaborator, Author):
@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) on Jul 21, 2023
@pytorchmergebot:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot deleted the gh/thiagocrepaldi/6/head branch on July 25, 2023 14:16