Skip to content

Commit

Permalink
Use torch.nn.functional.linear in RGB preview code.
Browse files Browse the repository at this point in the history
Add an optional bias to the latent RGB preview code.
  • Loading branch information
comfyanonymous committed Sep 29, 2024
1 parent 3bb4dec commit a9e459c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
1 change: 1 addition & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None

def process_in(self, latent):
Expand Down
16 changes: 12 additions & 4 deletions latent_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,20 @@ def decode_latent_to_preview(self, x0):


class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors

return preview_to_image(latent_image)


Expand Down Expand Up @@ -71,7 +79,7 @@ def get_previewer(device, latent_format):

if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer

def prepare_callback(model, steps, x0_output_dict=None):
Expand Down

0 comments on commit a9e459c

Please sign in to comment.