diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index ea9154f..63e3115 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device): class UniformDistribution(Distribution): + def __init__(self, vmin: float = 0.0, vmax: float = 1.0): + super().__init__() + self.vmin, self.vmax = vmin, vmax + def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): - return torch.rand(num_samples, device=device) + vmax, vmin = self.vmax, self.vmin + return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin """ Diffusion Methods """ @@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor: class LinearSchedule(Schedule): + def __init__(self, start: float = 1.0, end: float = 0.0): + super().__init__() + self.start, self.end = start, end + def forward(self, num_steps: int, device: Any) -> Tensor: - return torch.linspace(1.0, 0.0, num_steps, device=device) + return torch.linspace(self.start, self.end, num_steps, device=device) """ Samplers """ @@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: return alpha, beta def forward( # type: ignore - self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs + self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs ) -> Tensor: - b = noise.shape[0] - sigmas = self.schedule(num_steps + 1, device=noise.device) + b = x_noisy.shape[0] + sigmas = self.schedule(num_steps + 1, device=x_noisy.device) sigmas = repeat(sigmas, "i -> i b", b=b) - sigmas_batch = extend_dim(sigmas, dim=noise.ndim + 1) + sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) alphas, betas = self.get_alpha_beta(sigmas_batch) - x_noisy = noise * sigmas_batch[0] progress_bar = tqdm(range(num_steps), disable=not show_progress) for i in progress_bar: diff --git a/setup.py b/setup.py index 23b7b19..3c98e62 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.1.0", + version="0.1.1", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",