Improve fake mode support by adding fake_context to ExportOutput #105247
Conversation
Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem, because `ExportOutput.save` has a `model_state_dict` argument in which users can specify any state they want. With this PR, the `fake_context` specified as an `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load the additional `state_dict` that was captured by `enable_fake_mode`. This helps address #105233 for models that call `fake_model.load_state_dict` under the hood, as potential buffers not returned by `model.state_dict()` may be captured.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105247
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 5549bb6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. ghstack-source-id: 0a5be659b2104d6723df86bb62ef7e9bfa3298ba Pull Request resolved: #105247
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. [ghstack-poisoned]
Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. ghstack-source-id: 14b67f83451df103ac4efc2652eb588cc6967d2b Pull Request resolved: #105247
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. ps: #105464 tracks pending tasks/limitations from this PR [ghstack-poisoned]
…Output" Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the additional state_dict would not be reused by `ExportOutput.save` during ONNX proto creation. That is not necessarily a problem because `ExportOutput.save` has a `model_state_dict` in which they can specify any state they want. With this PR, the `fake_context` specified as `ExportOption` to `torch.onnx.dynamo_export` is saved at `ExportOutput` and used by `save` to load additional `state_dict` that was captured by `enable_fake_mode` This helps addressing #105233 for models that call `fake_model.load _state_dict` under the hood as potential buffers not returned by `model.state_dict()` may be captured. ps: #105464 tracks pending tasks/limitations from this PR [ghstack-poisoned]
Hmm, I'm wondering if the larger problem is how to handle the tensor constants that are neither parameters nor persistent buffers under the model. A few questions:
- Are non-persistent buffers the only case left for tensor constants that are not parameters or persistent buffers?
- What about only fakifying the things inside the checkpoint? That implies we set `allow_non_fake_inputs` and the issue is gone? (A rough sketch of this idea follows below.)
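For illustration only, a minimal sketch of what "fakifying only the checkpoint contents" might look like. `FakeTensorMode` and its `allow_non_fake_inputs` flag are real PyTorch internals, but the checkpoint path and the overall approach are hypothetical, not what this PR implements:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch (not this PR's implementation): fakify only the tensors loaded from
# the checkpoint, while allowing real tensors elsewhere in the traced program.
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

checkpoint = torch.load("checkpoint.pt")  # hypothetical checkpoint path
fake_checkpoint = {
    name: fake_mode.from_tensor(tensor)  # convert each real tensor to a FakeTensor
    for name, tensor in checkpoint.items()
}
```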
# E.g., "tensor" is stored as the initializer of "attention_weight". | ||
# Step 1 is required because sometimes, tensor names are stored with prefix the dictionary | ||
# loaded by torch.load. | ||
for onnx_input_name in onnx_input_names: |
why is this change necessary?
I have added one more line to the previous comment block to explain the reasoning. Let me know if this is not sufficient:

> # This block tries to match the onnx initializer name with torch parameter/buffer

This is not actually related to the initializer renaming controlled by `rename_initializer`; it tries to match the torch name (from a state_dict) to the ONNX initializer name. For tiny GPT2, for example, the torch names are prefixed with `transformer.` while the ONNX names aren't, with or without the initializer rename feature (which only replaces `.` by `_`). With the name matching, the ONNX graph will have all initializers saved within the proto. (A toy sketch of this kind of matching follows below.)
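To make the matching concrete, here is a toy, hypothetical sketch of suffix-based name matching; `find_matching_torch_name` is an invented helper, not the PR's actual code:

```python
def find_matching_torch_name(onnx_initializer_name: str, state_dict: dict):
    """Hypothetical helper: match an ONNX initializer name against torch
    state_dict keys that may carry an extra prefix (e.g. "transformer.")."""
    for torch_name in state_dict:
        # Exact match, or the torch name ends with ".<onnx_name>" because the
        # checkpoint keys carry a module prefix the ONNX graph does not.
        if torch_name == onnx_initializer_name or torch_name.endswith(
            "." + onnx_initializer_name
        ):
            return torch_name
    return None

# e.g. find_matching_torch_name("wte.weight", ckpt) -> "transformer.wte.weight"
```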
I thought the name prefixing was only for the symbolic tracer graph. This is unexpected for the dynamo graph, since the restore-naming pass in `readability.py` retrieves names from both `named_parameters` and `named_buffers`. If there is any `transformer.` prefix for a torch name, it should be there for the ONNX initializer name as well. I wonder what went wrong in the middle.
IIRC, the `transformer.` prefix came from the checkpoint file, not the actual torch model names.
> Step 1 is required because sometimes, tensor names are stored with a prefix in the dictionary loaded by torch.load.

Hmm, OK, so names in the checkpoint are sometimes different from the names in `named_parameters` and `named_buffers`. I think there could be a more systematic fix then. Please also track this as one of the issues.
Can you help me point to the code where the non-persistent buffer data, that is not returned by `model.state_dict()`, is captured?
There is no such code :) Before and after this PR, it is the user's responsibility to provide a full state to the exporter if one is not available. In this corner case (e.g., a model with non-persistent buffers and no checkpoint with state), they have to provide it with something like:

```python
state_dict = real_model.state_dict()
state_dict.update(dict(real_model.named_buffers()))  # Save non-persistent buffers along with the rest of the model state_dict
```

If the user decides that a buffer is not to be persistent, there is nothing we can do about it unless they provide a checkpoint with the full state. This PR improves our export process when a checkpoint (not a state_dict) is present during initialization. From now on, we can capture the full state of a model when one is present in the checkpoint, without user intervention. This is done by moving the `fake_context` into `ExportOutput`, so that `save` can reuse the state captured by `enable_fake_mode`.

We can create a separate task to try to look for missing state and raise a friendly error to the user as early as possible in the export process, but it shouldn't be "normal" for users to instantiate models with "incomplete" state.
We will learn about corner cases and fix them as we use this feature more often. We didn't know about the non-persistent buffers until a few days ago, but now we have an official way to capture non-persistent buffers from the checkpoint and a workaround for users to manually specify them (if ever needed). (A minimal demo of the non-persistent-buffer behavior is sketched below.)
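For context, a minimal, self-contained sketch of why non-persistent buffers need the workaround above. The model here is made up, but `register_buffer(..., persistent=False)`, `state_dict()`, and `named_buffers()` are standard PyTorch APIs: non-persistent buffers appear in `named_buffers()` but not in `state_dict()`:

```python
import torch

class TinyModel(torch.nn.Module):  # made-up model for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Non-persistent buffer: excluded from state_dict(), so a plain
        # checkpoint round-trip silently drops it.
        self.register_buffer("scale", torch.ones(4), persistent=False)

    def forward(self, x):
        return self.linear(x) * self.scale

m = TinyModel()
print("scale" in m.state_dict())           # False
print("scale" in dict(m.named_buffers()))  # True
```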
That issue is not tackled by this PR. Here, all we care about is capturing a state when one is available. Before this PR, we discarded state that was given for free and then required the user to specify it. #105477 gives a first shot at that.
I see.
We should probably follow up on the naming prefix issue; the others look good.
Added #105751 to my task list to address this.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Prior to this PR, if the user called `fake_model.load_state_dict()` from within `enable_fake_mode`, the initial model state dict (including non-persistent buffers) would not be reused by `ExportOutput.save` during ONNX proto creation.

That is not necessarily a bug, because `ExportOutput.save` has a `model_state_dict` argument in which users can specify any state they want. However, it can be a hassle: if the user doesn't provide a full state, including non-persistent buffers, the resulting ONNX graph would require the missing buffers to be specified as inputs during execution.

With this PR, `enable_fake_mode` is improved to capture the initial model state, including any non-persistent buffers. This reference (not the actual data) is persisted within `ExportOutput` and used by `save` to load the additional `state_dict` that was captured by `enable_fake_mode`. The result is an ONNX graph with all model state, without the user having to specify the non-persistent buffers.

This helps address #105233 for models that call `fake_model.load_state_dict` under the hood, as potential buffers not returned by `model.state_dict()` may be captured.

ps: #105464 tracks pending tasks/limitations from this PR
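For reference, a rough sketch of the end-to-end workflow described above. `TinyModel`, the file paths, and the input shape are illustrative assumptions; `enable_fake_mode`, `ExportOptions(fake_context=...)`, `dynamo_export`, and `ExportOutput.save` are the APIs this PR discusses:

```python
import torch

class TinyModel(torch.nn.Module):  # placeholder model for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

with torch.onnx.enable_fake_mode() as fake_context:
    fake_model = TinyModel()
    # State loaded here is captured by enable_fake_mode; after this PR it is
    # reused by ExportOutput.save without re-passing model_state_dict.
    fake_model.load_state_dict(torch.load("checkpoint.pt"))  # hypothetical checkpoint
    fake_input = torch.randn(1, 4)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
    fake_model, fake_input, export_options=export_options
)
export_output.save("model.onnx")  # checkpoint state captured above ends up in the proto
```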