Skip to content

Commit

Permalink
feat: add clamp option
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 24, 2022
1 parent 24ff00f commit b1b859e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
4 changes: 3 additions & 1 deletion audio_diffusion_pytorch/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,14 @@ def __init__(
sampler: Sampler,
sigma_schedule: Schedule,
num_steps: Optional[int] = None,
clamp: bool = True,
):
super().__init__()
self.denoise_fn = diffusion.denoise_fn
self.sampler = sampler
self.sigma_schedule = sigma_schedule
self.num_steps = num_steps
self.clamp = clamp

# Check sampler is compatible with diffusion type
sampler_class = sampler.__class__.__name__
Expand All @@ -581,7 +583,7 @@ def forward(
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
# Sample using sampler
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
x = x.clamp(-1.0, 1.0)
x = x.clamp(-1.0, 1.0) if self.clamp else x
return x


Expand Down
7 changes: 3 additions & 4 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def sample(
num_steps: int,
sigma_schedule: Schedule,
sampler: Sampler,
clamp: bool,
**kwargs,
) -> Tensor:
diffusion_sampler = DiffusionSampler(
diffusion=self.diffusion,
sampler=sampler,
sigma_schedule=sigma_schedule,
num_steps=num_steps,
clamp=clamp,
)
return diffusion_sampler(noise, **kwargs)

Expand Down Expand Up @@ -251,10 +253,7 @@ def get_default_model_kwargs():


def get_default_sampling_kwargs():
return dict(
sigma_schedule=LinearSchedule(),
sampler=VSampler(),
)
return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)


class AudioDiffusionModel(Model1d):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.74",
version="0.0.75",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit b1b859e

Please sign in to comment.