Skip to content

Commit

Permalink
Add docstring for the AutoencoderKL's decode (huggingface#5242)
Browse files Browse the repository at this point in the history
* Add docstring for the AutoencoderKL's decode

huggingface#5230

* Follow the style guidelines in AutoencoderKL's decode

huggingface#5230

---------

Co-authored-by: stano <>
  • Loading branch information
freespirit committed Oct 2, 2023
1 parent b50d4f3 commit 3465edd
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions models/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,20 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod

@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
Expand Down

0 comments on commit 3465edd

Please sign in to comment.