feat: add option to use rel pos
flavioschneider committed Oct 21, 2022
1 parent 7f6151e commit e22b4e9
Showing 3 changed files with 126 additions and 4 deletions.
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/model.py
@@ -238,6 +238,7 @@ def get_default_model_kwargs():
attention_heads=8,
attention_features=64,
attention_multiplier=2,
attention_use_rel_pos=False,
resnet_groups=8,
kernel_multiplier_downsample=2,
use_nearest_upsample=False,
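The new default leaves the relative positional bias switched off. As a usage sketch (not part of this commit), the option could be enabled by overriding the returned defaults; this assumes get_default_model_kwargs() returns a plain dict that is forwarded to the UNet1d-based model, and the bucket/distance values below are illustrative rather than defaults from this change.

from audio_diffusion_pytorch.model import get_default_model_kwargs

kwargs = get_default_model_kwargs()
kwargs.update(
    attention_use_rel_pos=True,            # turn the new option on
    attention_rel_pos_num_buckets=32,      # illustrative value, required once enabled
    attention_rel_pos_max_distance=128,    # illustrative value, required once enabled
)
# kwargs would then be passed to the UNet1d-based model as before.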
127 changes: 124 additions & 3 deletions audio_diffusion_pytorch/modules.py
@@ -315,6 +315,55 @@ def __init__(
"""


class RelativePositionBias(nn.Module):
def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.num_heads = num_heads
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

@staticmethod
def _relative_position_bucket(
relative_position: Tensor, num_buckets: int, max_distance: int
):
num_buckets //= 2
ret = (relative_position >= 0).to(torch.long) * num_buckets
n = torch.abs(relative_position)

max_exact = num_buckets // 2
is_small = n < max_exact

val_if_large = (
max_exact
+ (
torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).long()
)
val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
)

ret += torch.where(is_small, n, val_if_large)
return ret

def forward(self, num_queries: int, num_keys: int) -> Tensor:
i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")

relative_position_bucket = self._relative_position_bucket(
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
)

bias = self.relative_attention_bias(relative_position_bucket)
bias = rearrange(bias, "m n h -> 1 h m n")
return bias
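
As a quick illustration of the new module (a sketch, with arbitrary sizes): the returned bias has shape (1, num_heads, num_queries, num_keys), so it can be broadcast-added to the attention logits.

from audio_diffusion_pytorch.modules import RelativePositionBias

rel_pos = RelativePositionBias(num_buckets=32, max_distance=128, num_heads=8)
bias = rel_pos(num_queries=256, num_keys=256)
print(bias.shape)  # torch.Size([1, 8, 256, 256])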


def FeedForward(features: int, multiplier: int) -> nn.Module:
mid_features = features * multiplier
return nn.Sequential(
@@ -331,19 +380,33 @@ def __init__(
*,
head_features: int,
num_heads: int,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.scale = head_features ** -0.5
self.num_heads = num_heads
self.use_rel_pos = use_rel_pos
mid_features = head_features * num_heads

if use_rel_pos:
assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
self.rel_pos = RelativePositionBias(
num_buckets=rel_pos_num_buckets,
max_distance=rel_pos_max_distance,
num_heads=num_heads,
)

self.to_out = nn.Linear(in_features=mid_features, out_features=features)

def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Split heads
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
# Compute similarity matrix
- sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+ sim = sim * self.scale
# Get attention matrix with softmax
attn = sim.softmax(dim=-1)
# Compute values
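
For illustration, a sketch of AttentionBase with the bias enabled (sizes are arbitrary; the remainder of forward, collapsed in this view, is assumed to be the usual softmax-weighted readout and output projection).

import torch
from audio_diffusion_pytorch.modules import AttentionBase

attn = AttentionBase(
    512,                        # features
    head_features=64,
    num_heads=8,
    use_rel_pos=True,
    rel_pos_num_buckets=32,     # illustrative value
    rel_pos_max_distance=128,   # illustrative value
)
q = k = v = torch.randn(2, 256, 8 * 64)  # (batch, length, num_heads * head_features)
out = attn(q, k, v)  # a (1, 8, 256, 256) relative position bias is added to the logits
print(out.shape)     # expected: torch.Size([2, 256, 512])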
@@ -360,6 +423,9 @@ def __init__(
head_features: int,
num_heads: int,
context_features: Optional[int] = None,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.context_features = context_features
@@ -375,7 +441,12 @@ def __init__(
in_features=context_features, out_features=mid_features * 2, bias=False
)
self.attention = AttentionBase(
- features, num_heads=num_heads, head_features=head_features
+ features,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
)

def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
@@ -402,14 +473,22 @@ def __init__(
num_heads: int,
head_features: int,
multiplier: int,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
):
super().__init__()

self.use_cross_attention = exists(context_features) and context_features > 0

self.attention = Attention(
- features=features, num_heads=num_heads, head_features=head_features
+ features=features,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
)

if self.use_cross_attention:
@@ -418,6 +497,9 @@ def __init__(
num_heads=num_heads,
head_features=head_features,
context_features=context_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)

self.feed_forward = FeedForward(features=features, multiplier=multiplier)
@@ -443,6 +525,9 @@ def __init__(
num_heads: int,
head_features: int,
multiplier: int,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
):
super().__init__()
@@ -465,6 +550,9 @@ def __init__(
num_heads=num_heads,
multiplier=multiplier,
context_features=context_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
for i in range(num_layers)
]
@@ -552,6 +640,9 @@ def __init__(
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
attention_use_rel_pos: Optional[bool] = None,
attention_rel_pos_max_distance: Optional[int] = None,
attention_rel_pos_num_buckets: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
):
@@ -588,6 +679,7 @@ def __init__(
exists(attention_heads)
and exists(attention_features)
and exists(attention_multiplier)
and exists(attention_use_rel_pos)
)
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
@@ -596,6 +688,9 @@ def __init__(
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features,
use_rel_pos=attention_use_rel_pos,
rel_pos_num_buckets=attention_rel_pos_num_buckets,
rel_pos_max_distance=attention_rel_pos_max_distance,
)

if self.use_extract:
@@ -659,6 +754,9 @@ def __init__(
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
attention_use_rel_pos: Optional[bool] = None,
attention_rel_pos_max_distance: Optional[int] = None,
attention_rel_pos_num_buckets: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
):
@@ -689,6 +787,7 @@ def __init__(
exists(attention_heads)
and exists(attention_features)
and exists(attention_multiplier)
and exists(attention_use_rel_pos)
)
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
@@ -697,6 +796,9 @@ def __init__(
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features,
use_rel_pos=attention_use_rel_pos,
rel_pos_num_buckets=attention_rel_pos_num_buckets,
rel_pos_max_distance=attention_rel_pos_max_distance,
)

self.upsample = Upsample1d(
@@ -756,6 +858,9 @@ def __init__(
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
attention_use_rel_pos: Optional[bool] = None,
attention_rel_pos_max_distance: Optional[int] = None,
attention_rel_pos_num_buckets: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
):
@@ -774,6 +879,7 @@ def __init__(
exists(attention_heads)
and exists(attention_features)
and exists(attention_multiplier)
and exists(attention_use_rel_pos)
)
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
@@ -782,6 +888,9 @@ def __init__(
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features,
use_rel_pos=attention_use_rel_pos,
rel_pos_num_buckets=attention_rel_pos_num_buckets,
rel_pos_max_distance=attention_rel_pos_max_distance,
)

self.post_block = ResnetBlock1d(
@@ -844,6 +953,9 @@ def __init__(
context_features: Optional[int] = None,
context_channels: Optional[Sequence[int]] = None,
context_embedding_features: Optional[int] = None,
attention_use_rel_pos: bool = False,
attention_rel_pos_max_distance: Optional[int] = None,
attention_rel_pos_num_buckets: Optional[int] = None,
):
super().__init__()
out_channels = default(out_channels, in_channels)
@@ -931,6 +1043,9 @@ def __init__(
attention_heads=attention_heads,
attention_features=attention_features,
attention_multiplier=attention_multiplier,
attention_use_rel_pos=attention_use_rel_pos,
attention_rel_pos_max_distance=attention_rel_pos_max_distance,
attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
)
for i in range(num_layers)
]
@@ -945,6 +1060,9 @@ def __init__(
attention_heads=attention_heads,
attention_features=attention_features,
attention_multiplier=attention_multiplier,
attention_use_rel_pos=attention_use_rel_pos,
attention_rel_pos_max_distance=attention_rel_pos_max_distance,
attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
)

self.upsamples = nn.ModuleList(
@@ -966,6 +1084,9 @@ def __init__(
attention_heads=attention_heads,
attention_features=attention_features,
attention_multiplier=attention_multiplier,
attention_use_rel_pos=attention_use_rel_pos,
attention_rel_pos_max_distance=attention_rel_pos_max_distance,
attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
)
for i in reversed(range(num_layers))
]
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.69",
version="0.0.70",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
