[bitsandbytes] allow direct CUDA placement of pipelines loaded with bnb components #9840

Open
sayakpaul wants to merge 17 commits into main

Conversation

sayakpaul (Member)

What does this PR do?

When a pipeline is loaded with models that have a quantization config, we should still be able to call to("cuda") on the pipeline object. For GPUs with enough memory (such as a 4090), this has performance benefits (as demonstrated below).

| Model CPU Offload | Batch Size | Time (seconds) | Memory (GB) |
|---|---|---|---|
| False | 1 | 19.316 | 14.935 |
| True  | 1 | 36.746 | 12.139 |
| False | 4 | 80.665 | 20.576 |
| True  | 4 | 98.612 | 12.138 |

Flux.1 Dev, steps: 30
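
In other words, the call pattern this PR enables is the following (a minimal sketch; the full benchmarking setup is further below):

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Quantize the Flux transformer to 4-bit NF4 via bitsandbytes.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=transformer, torch_dtype=torch.bfloat16
)

# With this PR, placing the whole pipeline (including the quantized component)
# on the GPU works instead of raising.
pipeline.to("cuda")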

Currently, calling to("cuda") is not possible because:

from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as BnbConfig
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Load the T5 text encoder in 4-bit NF4 through bitsandbytes.
text_encoder_2_config = BnbConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id,
    subfolder="text_encoder_2",
    quantization_config=text_encoder_2_config,
    torch_dtype=torch.bfloat16,
)

# Inspect the hook that accelerate attaches during quantized loading.
print(text_encoder_2._hf_hook)

which prints:

AlignDevicesHook(execution_device=0, offload=False, io_same_device=True, offload_buffers=False, place_submodules=True, skip_keys=None)

This AlignDevicesHook makes diffusers treat the pipeline as sequentially offloaded, which is why this line complains:

if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":

This PR fixes that behavior.
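
For context on the fix: the AlignDevicesHook above is attached by accelerate when transformers loads the quantized text encoder, and the existing check treats any such hook as evidence of sequential CPU offloading. The sketch below shows one way the check could skip bnb-quantized components. It is illustrative only, not necessarily the exact change in this PR; is_loaded_in_4bit / is_loaded_in_8bit are the attributes transformers sets on bitsandbytes-quantized models.

from accelerate.hooks import AlignDevicesHook

def _is_bnb_quantized(module) -> bool:
    # transformers marks bitsandbytes-loaded models with these attributes.
    return getattr(module, "is_loaded_in_4bit", False) or getattr(module, "is_loaded_in_8bit", False)

def _looks_sequentially_offloaded(module) -> bool:
    # Treat an AlignDevicesHook as a sign of sequential CPU offload only when
    # the module is not bnb-quantized; bnb loading attaches a similar hook
    # even though nothing is actually offloaded to the CPU.
    hook = getattr(module, "_hf_hook", None)
    return isinstance(hook, AlignDevicesHook) and not _is_bnb_quantized(module)

With a check along these lines, to("cuda") can proceed for pipelines with bnb-loaded components, while pipelines that really use enable_sequential_cpu_offload() are still rejected.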

Benchmarking code:

from diffusers import DiffusionPipeline, FluxTransformer2DModel, BitsAndBytesConfig
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as BnbConfig
import torch.utils.benchmark as benchmark
import torch 
import fire

# Measure the mean wall-clock time of f(*args, **kwargs) with torch.utils.benchmark.
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def load_pipeline(model_cpu_offload=False):
    ckpt_id = "black-forest-labs/FLUX.1-dev"

    # Quantize the Flux transformer to 4-bit NF4 via bitsandbytes.
    transformer_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        ckpt_id, 
        subfolder="transformer",
        quantization_config=transformer_config,
        torch_dtype=torch.bfloat16
    )

    text_encoder_2_config = BnbConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        ckpt_id,
        subfolder="text_encoder_2",
        quantization_config=text_encoder_2_config,
        torch_dtype=torch.bfloat16
    )

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        text_encoder_2=text_encoder_2,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    if model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    else:
        # Direct CUDA placement of the full pipeline, which this PR enables
        # for pipelines containing bnb-quantized components.
        pipeline = pipeline.to("cuda")

    pipeline.set_progress_bar_config(disable=True)
    return pipeline

def run_pipeline(pipeline, batch_size=1):
    _ = pipeline(
        prompt="a dog sitting besides a sea", 
        guidance_scale=3.5, 
        max_sequence_length=512, 
        num_inference_steps=30,
        num_images_per_prompt=batch_size
    )


def main(batch_size: int = 1, model_cpu_offload: bool = False):
    pipeline = load_pipeline(model_cpu_offload=model_cpu_offload)

    # Warmup runs (excluded from the timed measurement).
    for _ in range(5):
        run_pipeline(pipeline)

    time = benchmark_fn(run_pipeline, pipeline, batch_size)
    # Peak GPU memory allocated so far (includes the warmup runs), in GB.
    memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
    print(f"{model_cpu_offload=}, {batch_size=} {time=} seconds {memory=} GB.")

    image = pipeline(
        prompt="a dog sitting besides a sea", 
        guidance_scale=3.5, 
        max_sequence_length=512, 
        num_inference_steps=30,
        num_images_per_prompt=1
    ).images[0]
    img_name = f"mco@{model_cpu_offload}-bs@{batch_size}.png"
    image.save(img_name)


if __name__ == "__main__":
    fire.Fire(main)
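
Assuming the script above is saved as, say, benchmark_flux_bnb.py (the filename is mine, not from the PR), fire exposes main()'s arguments as CLI flags, so the table above corresponds to invocations like python benchmark_flux_bnb.py --batch_size=4 --model_cpu_offload=True.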

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

SunMarc (Member) left a comment


Thanks for the PR! Left a suggestion.

src/diffusers/pipelines/pipeline_utils.py (outdated suggestion, resolved)
sayakpaul (Member, Author)

@SunMarc WDYT now?

SunMarc (Member) left a comment


Thanks for adding this! LGTM! I'll merge the PR on accelerate also.

sayakpaul (Member, Author)

Have run the integration tests and they are passing.

SunMarc (Member) commented Nov 18, 2024

> Have run the integration tests and they are passing.

On diffusers?

sayakpaul (Member, Author)

@SunMarc yes, on diffusers. Anywhere else they need to be run?

SunMarc (Member) commented Nov 18, 2024

No, I read that as a question, my bad ;)

# For `diffusers` it should not be a problem as we enforce the installation of a bnb version
# that already supports CPU placements.
else:
    module.to(device=device)
Collaborator

ok but this means for diffusers the transformers version would always meet the requirement, no? i.e. the check is_transformers_version(">", "4.44.0") will always pass

sayakpaul (Member, Author)

Agreed, but for the diffusers codepath we probably don't care about the transformers version, no?

Anything I'm missing?
