Skip to content

Commit

Permalink
feat: add magnitude, phase info
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 26, 2022
1 parent fc1fb9b commit c458fc2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,14 @@ def encode(
log_magnitude = rearrange(torch.log(magnitude), "b c f t -> b (c f) t")
return super().encode(log_magnitude, with_info)

def decode(self, z: Tensor) -> Tensor:
def decode( # type: ignore
self, z: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
f = self.frequency_channels
stft = super().decode(z)
stft = rearrange(stft, "b (c f i) t -> b (c i) f t", i=2, f=f)
log_magnitude, phase = stft.chunk(chunks=2, dim=1)
return self.stft.decode(magnitude=torch.exp(log_magnitude), phase=phase)
magnitude = torch.exp(log_magnitude)
wave = self.stft.decode(magnitude, phase)
info = dict(magnitude=magnitude, phase=phase)
return (wave, info) if with_info else wave
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.77",
version="0.0.78",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit c458fc2

Please sign in to comment.