Skip to content

Commit

Permalink
Update final CPU offloading code for more diffusion pipelines (huggin…
Browse files Browse the repository at this point in the history
…gface#5589)

* Update final model offload for more pipelines

Add test to ensure all pipeline components are returned to CPU after
execution with model offloading

* Add comment to explain early UNet offload in Text-to-Video pipeline

* Style
  • Loading branch information
clarencechen authored and kashif committed Nov 11, 2023
1 parent f7a3a86 commit 96c1b9e
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1604,9 +1604,8 @@ def denoising_value_valid(dnv):

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1433,9 +1433,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,9 +864,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1031,9 +1031,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,8 @@ def __call__(
if output_type == "pil" and self.watermarker is not None:
image = self.watermarker.apply_watermark(image)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -839,9 +839,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,8 @@ def __call__(

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def __call__(
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)

# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")

Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,14 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):

max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
self.assertTrue(
all(
v.device == "cpu"
for k, v in pipe.components.values()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
),
"CPU offloading should leave all pipeline components on the CPU after inference",
)

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
Expand Down

0 comments on commit 96c1b9e

Please sign in to comment.