Skip to content

Commit

Permalink
feat: diffMAE wrong when use_complex
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Nov 23, 2022
1 parent fd0b101 commit 5144b72
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
23 changes: 18 additions & 5 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def __init__(
encoder_multipliers: Sequence[int],
diffusion_type: str,
stft_num_fft: int,
stft_hop_length: int,
stft_use_complex: bool,
stft_window_length: Optional[int] = None,
encoder_patch_size: int = 1,
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
bottleneck_channels: Optional[int] = None,
Expand All @@ -209,6 +212,7 @@ def __init__(

encoder_kwargs, kwargs = groupby("encoder_", kwargs)
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
stft_kwargs, kwargs = groupby("stft_", kwargs)

# Compute context channels
context_channels = [0] * encoder_inject_depth
Expand All @@ -218,17 +222,26 @@ def __init__(
context_channels += [encoder_channels * encoder_multipliers[-1]]

self.spectrogram_channels = stft_num_fft // 2 + 1
self.stft_hop_length = stft_hop_length

self.encoder_stft = STFT(
num_fft=stft_num_fft,
hop_length=stft_hop_length,
window_length=stft_window_length,
use_complex=False, # Magnitude encoding
)

self.unet = UNet1d(
in_channels=in_channels,
stft_num_fft=stft_num_fft,
context_channels=context_channels,
use_stft=True,
stft_use_complex=stft_use_complex,
stft_num_fft=stft_num_fft,
stft_hop_length=stft_hop_length,
stft_window_length=stft_window_length,
**kwargs,
)

self.stft = self.unet.stft

self.diffusion = XDiffusion(
type=diffusion_type, net=self.unet, **diffusion_kwargs
)
Expand All @@ -251,7 +264,7 @@ def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
# Extract magnitude and encode
magnitude, _ = self.stft.encode(x)
magnitude, _ = self.encoder_stft.encode(x)
magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
latent, info = self.encoder(magnitude_flat, with_info=True)
# Apply bottlenecks if present
Expand All @@ -270,7 +283,7 @@ def forward( # type: ignore
def decode(self, latent: Tensor, **kwargs) -> Tensor:
b = latent.shape[0]
length = closest_power_2(
self.stft.hop_length * latent.shape[2] * self.encoder_downsample_factor
self.stft_hop_length * latent.shape[2] * self.encoder_downsample_factor
)
# Compute noise by inferring shape from latent length
noise = torch.randn(b, self.in_channels, length, device=latent.device)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.89",
version="0.0.90",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 5144b72

Please sign in to comment.