[do not merge] testing rotary embedding + torch.compile #9321

Closed
yiyixuxu wants to merge 4 commits

Conversation


yiyixuxu (Collaborator) commented Aug 29, 2024

This is the script I use to generate the trace:

import os
os.environ['TORCH_LOGS'] = 'graph_breaks, dynamo, recompiles'
os.environ['TORCHDYNAMO_VERBOSE'] = '1'

import torch

import logging
logging.basicConfig(level=logging.INFO)

from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import gc


torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())

def profiler_runner(fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            on_trace_ready=tensorboard_trace_handler("./yiyi_trace")) as prof:
        result = fn(*args, **kwargs)
    return result


pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")

def run_inference(pipe):
    for i in range(5):
        with record_function(f"pipeline_run_number_{i}"):
            _ = pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=3,
                guidance_scale=3.5,
                max_sequence_length=512,
                generator=torch.manual_seed(42),
                height=1024,
                width=1024,
            )

_ = profiler_runner(run_inference, pipe)
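
(Since the profiler uses tensorboard_trace_handler, the trace files written to ./yiyi_trace can be viewed with TensorBoard, assuming the torch_tb_profiler plugin is installed: tensorboard --logdir ./yiyi_trace.)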
Benchmark test

I'm using the testing script below; some numbers follow.
(I don't think putting arange on device improved anything, but there is a difference from 0.30.1-patch, so we should look into that.)

main before the latest commit (with graph-break warning)

# Execution time: 2.287 sec
# Memory: 22.805 GiB

main with latest commit (#9307)

# Execution time: 2.256 sec
# Memory: 22.346 GiB

this PR (arange on GPU)

# Execution time: 2.269 sec
# Memory: 22.348 GiB

0.30.1-patch (using the original Flux rotary embeds, before PR #9074)

# Execution time: 2.226 sec
# Memory: 22.346 GiB
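
For context, "arange on GPU" means building the rotary-embedding positions directly on the target device instead of creating them on CPU and copying them over. A rough sketch of the two variants being compared (made-up function names, not the actual diffusers implementation):

import torch

def rope_freqs_cpu(dim, pos_len, device):
    # baseline: positions built on CPU, then copied to the GPU afterwards
    pos = torch.arange(pos_len, dtype=torch.float32)
    freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return torch.outer(pos, freqs).to(device)

def rope_freqs_gpu(dim, pos_len, device):
    # "arange on GPU": positions allocated directly on the device, so the
    # compiled forward pass avoids an extra host-to-device copy
    pos = torch.arange(pos_len, dtype=torch.float32, device=device)
    freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    return torch.outer(pos, freqs)

Avoiding the copy was the expected benefit under torch.compile; as noted above, the numbers don't show a meaningful difference here.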

Testing script:

import os
os.environ['TORCH_LOGS'] = 'graph_breaks'

import torch
import torch.utils.benchmark as benchmark
import gc

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())

def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"

def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")

def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )

flush()

for _ in range(5):
    run_inference(pipe)

flush()

elapsed = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB
print(f" Execution time: {elapsed} sec")
print(f" Memory: {memory} GiB")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

yiyixuxu marked this pull request as draft on August 30, 2024 at 19:26
@@ -716,21 +718,24 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

This would be another candidate for syncs because we're iterating over a tensor here. See this internal link for more details.
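
A minimal sketch of the sync pattern being flagged (made-up tensors, not the actual pipeline code):

import torch

timesteps = torch.linspace(1000.0, 0.0, 5, device="cuda")

# Iterating over a CUDA tensor yields 0-d CUDA tensors; any host-side use of
# them (comparisons, float(), printing) triggers a blocking device sync.
for t in timesteps:
    if t < 500:  # implicit t.item() -> sync on every iteration
        pass

# One way to avoid the per-step sync (not necessarily the right fix here):
# pull the values to the host once, before the loop.
for t in timesteps.tolist():
    if t < 500:  # plain Python float, no sync
        pass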


cpuhrsch commented Aug 31, 2024

I'd call pipe a few times (say 5) within run_inference to make sure warmup and initial setup cost isn't being captured. You can also wrap it in a record_function context and annotate each iteration just to make it more obvious in the resulting trace.

def run_inference(pipe):
    for i in range(5):
        with torch.autograd.profiler.record_function(f"Run_number_{i}"):
            _ = pipe(
    [...]


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label (Issues that haven't received updates) on Sep 29, 2024
yiyixuxu closed this on Sep 30, 2024