Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] fix graph break problems partially #5453

Merged
merged 14 commits into from
Oct 23, 2023

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Oct 19, 2023

Fixes graph break problems for T2I Adapters (both SD and SDXL).

ControlNets are still failing. I am trying to get to the bottom of it.

@DN6, if you want to double-check the fixes proposed in this PR, I'd appreciate it.

Comment on lines -849 to -878
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# Forward upsample size to force interpolation output size.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.compile() fails to compile these kinds of iterators right now.

@sayakpaul
Copy link
Member Author

The failing test seems unrelated to the PR.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Comment on lines 980 to 990
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
"2.1"
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
torch._inductor.cudagraph_mark_step_begin()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this is in some sense a breaking change from PT, do we really have to add version specific code here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no other way to support compiled ControlNets otherwise in PT 2.1, sadly.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM in general,just some nits. Let's also flag this issue for PT to take a look. Torch compile seems to be backward broken here between 2.1 and 2.0

src/diffusers/pipelines/controlnet/pipeline_controlnet.py Outdated Show resolved Hide resolved
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:

Do we need this really? It's not super pretty and looks like a bug in PT 2.1 . Also are we sure the code works fine with PT 2.0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed internally.

@sayakpaul
Copy link
Member Author

@patrickvonplaten let me know your final thoughts on merging this PR. Personally, I am okay hearing back from the PT folks and then deciding the course of action here. But I guess, it's good for us to be at least aware of the situation and a reasonable workaround.

@patrickvonplaten
Copy link
Contributor

Ok to merge from my side. It'll take a while until this would be fixed in PT 2.1, so think no matter what we should merge this. Great job!

@sayakpaul sayakpaul merged commit 48ce118 into main Oct 23, 2023
13 checks passed
@sayakpaul sayakpaul deleted the fix/controlnet-graph-break branch October 23, 2023 18:11
linoytsaban pushed a commit to linoytsaban/diffusers that referenced this pull request Oct 24, 2023
* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* rpopagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
kashif pushed a commit to kashif/diffusers that referenced this pull request Nov 11, 2023
* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* rpopagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* rpopagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* rpopagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants