From f7ee6b0c8b80d2cec14c12a39a088b2cb11149da Mon Sep 17 00:00:00 2001 From: stano <> Date: Fri, 29 Sep 2023 20:47:47 +0300 Subject: [PATCH 1/2] Add docstring for the AutoencoderKL's decode #5230 --- src/diffusers/models/autoencoder_kl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 21c8f64fd916..dba2c4cf29bf 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -281,6 +281,20 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod @apply_forward_hook def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) From 7f2d5d3af71e7d8e42baa1d9c4992a7adf44ae01 Mon Sep 17 00:00:00 2001 From: stano <> Date: Mon, 2 Oct 2023 18:33:54 +0300 Subject: [PATCH 2/2] Follow the style guidelines in AutoencoderKL's decode #5230 --- src/diffusers/models/autoencoder_kl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index dba2c4cf29bf..7e3b925df714 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -283,17 +283,17 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. - + Args: z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - + Returns: [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. - + """ if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]