
Questions about inference #7

Open
prefixRAINSTARsuffix opened this issue Jan 23, 2024 · 2 comments

@prefixRAINSTARsuffix

Hi, I have read your paper and have a few questions about the training and inference process:

  1. Training uses concat[x, x_m, z_t] to predict the noise added to z_t. Is that the correct understanding?

  2. At inference time the input is concat[x, x_m, z_T]. What is being denoised in this process: z_T itself, or the whole concat[x, x_m, z_T]? (See the sketch after this list.)

  3. Standard SD simply samples a noise tensor as the initial input. If I directly replace z_T with random noise at inference, can it still achieve the original results? (In my understanding, z_T should be equivalent to pure noise.)
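
For reference, here is a minimal illustrative sketch of the channel-wise concatenation these questions refer to. The shapes, and the reading of x as the downsampled mask and x_m as the masked-image latent, are assumptions that mirror the pipeline code quoted below; this is not the repository's training code. It shows that only z_t is the denoising target, while the conditioning channels are re-attached unchanged at every step:

import torch

# Hypothetical shapes for illustration (batch 1, 64x64 latents).
z_t = torch.randn(1, 4, 64, 64)   # noisy latent, the only tensor being denoised
x   = torch.randn(1, 1, 64, 64)   # downsampled mask (conditioning, kept fixed)
x_m = torch.randn(1, 4, 64, 64)   # masked-image latent (conditioning, kept fixed)

# Channel-wise concatenation: 4 + 1 + 4 = 9 input channels for the UNet.
unet_input = torch.cat([z_t, x, x_m], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])

# At each timestep the UNet predicts the noise of z_t; the scheduler step
# updates z_t only, and x / x_m are concatenated again unchanged.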

@chenhaoxing (Owner)

# Imports assumed for this snippet; the randn_tensor import path varies with the
# installed diffusers version (older releases expose it from diffusers.utils).
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from transformers import TrOCRProcessor, VisionEncoderDecoderModel


class StableDiffusionDocPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: VisionEncoderDecoderModel,
        tokenizer: TrOCRProcessor,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device
        
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[torch.FloatTensor, PIL.Image.Image] = None,
        location: Optional[np.ndarray]= None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        # callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        # callback_steps: int = 1,
    ):
        if image is None:
            raise ValueError("`image` input cannot be undefined.")

        if mask_image is None:
            raise ValueError("`mask_image` input cannot be undefined.")

        device = self._execution_device
        # 2. Encode input prompt
        # if type(prompt) == "string":
        #     ttf_imgs = draw_text(image[0].shape[:2][::-1], prompt, location)

        prompt_images = self.tokenizer(images=prompt, return_tensors="pt").pixel_values.to(device)
        ocr_feature = self.text_encoder(prompt_images)
        prompt_embeds = ocr_feature.last_hidden_state.detach()
        
        # 3. Define call parameters
        batch_size = prompt_embeds.shape[0]

        # 4. Preprocess mask and image
        
        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        # note: these encoded-image latents and `noise` are not used later;
        # the denoising loop starts from the pure-noise `latents` created below
        vae = self.vae
        latents = vae.encode(image).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        noise = torch.randn_like(latents)

        # 7. Prepare mask latent variables
        # Rex: prepare mask && mask latent as input of UNET  
        vae_scale_factor = self.vae_scale_factor     
        height, width = mask.shape[-2:]
        # downsample the mask to the latent spatial resolution
        mask = torch.nn.functional.interpolate(
            mask, size=(height // vae_scale_factor, width // vae_scale_factor)
        )
        
        masked_image_latents = vae.encode(mask_image).latent_dist.sample()
        masked_image_latents = masked_image_latents * vae.config.scaling_factor

        shape = (1, vae.config.latent_channels, height // vae_scale_factor, width // vae_scale_factor)

        # initial latents are pure Gaussian noise; the hard-coded seed here
        # overrides the `generator` argument passed to __call__
        latents = randn_tensor(shape, generator=torch.manual_seed(20), device=device)
        
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        
        # 10. Denoising loop
        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # no classifier-free guidance is applied here; `guidance_scale` is unused
                latent_model_input = latents

                # concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                progress_bar.update()
                
        # 11. Post-processing
        pred_latents = 1 / vae.config.scaling_factor * latents
        image_vae = vae.decode(pred_latents).sample

        # 13. Convert to a numpy image in [0, 255]
        image = (image_vae / 2 + 0.5).clamp(0, 1) * 255.0
        image = image.cpu().permute(0, 2, 3, 1).float().detach().numpy()
        return image, image_vae
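
The pipeline returns a float numpy array scaled to [0, 255] rather than PIL images. A small hypothetical helper (its name and signature are my own, not part of the repository) to finish the conversion could look like this:

import numpy as np
import PIL.Image

def to_pil(images: np.ndarray) -> list:
    """Convert a (B, H, W, C) float array in [0, 255] to a list of PIL images."""
    images = np.clip(images, 0, 255).astype(np.uint8)
    return [PIL.Image.fromarray(img) for img in images]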

@prefixRAINSTARsuffix (Author)

@chenhaoxing Thanks for the code.

You mention in the paper that "After dimension adjustment through a convolution layer, the feature vector ẑ_t = Conv(z'_t) is fed into the UNet", but I don't see this convolution layer in either the training or the inference code; [latent_model_input, mask, masked_image_latents] is concatenated and fed directly into the UNet. Doesn't that cause a channel-dimension mismatch?

I tried to reproduce this part, but the quality of the generated images is very poor. Could you explain how exactly it is done?
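
For context, here is a minimal sketch of what such a dimension-adjusting convolution could look like: a 1x1 Conv2d that maps the concatenated channels back to the UNet's expected in_channels. This only illustrates the sentence quoted from the paper; it is not code from this repository, whose released pipeline feeds the 9-channel concatenation directly to a UNet that presumably has a matching in_channels.

import torch
import torch.nn as nn

# Hypothetical dimension adjustment: 4 (latent) + 1 (mask) + 4 (masked latent)
# input channels mapped down to the 4 channels a stock UNet expects.
channel_adjust = nn.Conv2d(in_channels=9, out_channels=4, kernel_size=1)

z_prime_t = torch.randn(1, 9, 64, 64)   # concat[z_t, mask, masked-image latent]
z_hat_t = channel_adjust(z_prime_t)     # shape: (1, 4, 64, 64)
print(z_hat_t.shape)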
