Skip to content

Commit

Permalink
feat: add parameters linear schedule, uniform distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Jan 20, 2023
1 parent 7517b9f commit a34014f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
22 changes: 15 additions & 7 deletions audio_diffusion_pytorch/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device):


class UniformDistribution(Distribution):
def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
super().__init__()
self.vmin, self.vmax = vmin, vmax

def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
return torch.rand(num_samples, device=device)
vmax, vmin = self.vmax, self.vmin
return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin


""" Diffusion Methods """
Expand Down Expand Up @@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:


class LinearSchedule(Schedule):
def __init__(self, start: float = 1.0, end: float = 0.0):
super().__init__()
self.start, self.end = start, end

def forward(self, num_steps: int, device: Any) -> Tensor:
return torch.linspace(1.0, 0.0, num_steps, device=device)
return torch.linspace(self.start, self.end, num_steps, device=device)


""" Samplers """
Expand All @@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
return alpha, beta

def forward( # type: ignore
self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs
self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
) -> Tensor:
b = noise.shape[0]
sigmas = self.schedule(num_steps + 1, device=noise.device)
b = x_noisy.shape[0]
sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
sigmas = repeat(sigmas, "i -> i b", b=b)
sigmas_batch = extend_dim(sigmas, dim=noise.ndim + 1)
sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
alphas, betas = self.get_alpha_beta(sigmas_batch)
x_noisy = noise * sigmas_batch[0]
progress_bar = tqdm(range(num_steps), disable=not show_progress)

for i in progress_bar:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.1.0",
version="0.1.1",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit a34014f

Please sign in to comment.