[Based|LinearAttn] Fix arg bugs (#28)
yzhangcs committed Jul 1, 2024
1 parent 98c176e commit 4351b6f
Showing 2 changed files with 18 additions and 41 deletions.
27 changes: 3 additions & 24 deletions fla/layers/based.py
@@ -15,10 +15,10 @@


class BasedLinearAttention(nn.Module):

def __init__(
self,
hidden_size: int,
l_max: int = 2048,
feature_dim: int = 16,
num_key_value_heads: int = 12,
num_heads: int = 12,
@@ -28,12 +28,9 @@ def __init__(
mode: str = "parallel",
):
super().__init__()
self.hidden_size
self.l_max = l_max
self.mode = mode
assert self.mode in ["fused_chunk", "parallel", 'chunk']

# linear attention
self.hidden_size = hidden_size
self.mode = mode
self.feature_name = feature_name
self.feature_dim = feature_dim
self.num_key_value_heads = num_key_value_heads
@@ -106,21 +103,3 @@ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor =
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)


if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 1024
dtype = torch.float32
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
y = model(x)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
y2 = model.forward_reference(x)
y2.backward(dy)
assert y.allclose(y2, 0, 1e-4), breakpoint()
assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
print("Pass")
32 changes: 15 additions & 17 deletions fla/layers/linear_attn.py
@@ -31,13 +31,9 @@ def __init__(
do_feature_map_norm: bool = False,
elementwise_affine: bool = True,
norm_eps: float = 1e-5,
**kwargs,
**kwargs
):
super().__init__()
assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."

assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."

self.hidden_size = hidden_size
self.mode = mode
@@ -96,7 +92,7 @@ def elu(x):
self.feature_map_q = nn.Identity()
self.feature_map_k = nn.Identity()
else:
raise NotImplementedError
raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
@@ -107,7 +103,7 @@ def elu(x):
elif output_norm == 'identity':
self.norm = nn.Identity()
else:
raise NotImplementedError
raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
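In linear_attn.py the upfront asserts on `feature_map` and `output_norm` are dropped; unsupported names now fall through to the final `else` branch, which raises a `NotImplementedError` carrying the same message. A hypothetical helper sketching that dispatch pattern follows (`build_output_norm` is not part of fla, and plain `nn.LayerNorm` stands in for fla's RMSNorm so the snippet stays self-contained).

import torch.nn as nn


def build_output_norm(output_norm: str, hidden_size: int, eps: float = 1e-5) -> nn.Module:
    # Unsupported names are reported by the dispatch itself rather than by an
    # assert that runs before any branch is taken.
    if output_norm == 'rmsnorm':
        return nn.LayerNorm(hidden_size, eps=eps)  # stand-in for fla's RMSNorm
    elif output_norm == 'identity':
        return nn.Identity()
    else:
        raise NotImplementedError(f"Not supported output norm `{output_norm}`.")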

@@ -127,26 +123,28 @@ def _initialize_weights(self, module: nn.Module):

def forward(self, x):
mode = self.mode
q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = self.feature_map_q(q)
k = self.feature_map_k(k)
if self.norm_q:
q = q / (q.sum(-1, keepdim=True) + 1e-4)
if self.norm_k:
k = k / (k.sum(-1, keepdim=True) + 1e-4)

q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
if self.num_kv_groups > 1:
k, v = (repeat(x, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
else:
k, v = (rearrange(x, 'b n (h d) -> b h n d', h=self.num_kv_heads) for x in (k, v))
if self.norm_q:
q = q / (q.sum(-1, True) + 1e-4)
if self.norm_k:
k = k / (k.sum(-1, True) + 1e-4)

if mode == 'chunk':
o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
o, final_state = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_chunk':
o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
o, final_state = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_recurrent':
o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
o, final_state = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
else:
raise NotImplementedError
o = self.norm(o)
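The forward fix touches both shape handling and the kernel calls: q, k, v are projected and passed through the feature maps while still flat, q is then split into `num_heads` heads while k and v use `num_kv_heads` and are repeated across groups when `num_kv_groups > 1`, and the chunk/fused kernels are unpacked as `(o, final_state)` pairs. Below is a standalone sketch of that ordering, assuming inputs that are already projected and feature-mapped, with a naive cumulative-sum reference standing in for the fla kernels (`linear_attn_forward` and the reference body are illustrative, not fla code).

import torch
from einops import rearrange, repeat


def linear_attn_forward(q, k, v, num_heads, num_kv_heads, norm_q=False, norm_k=False):
    num_kv_groups = num_heads // num_kv_heads
    # q gets one head per attention head; k/v keep num_kv_heads and are
    # repeated across groups, mirroring the GQA handling in the diff.
    q = rearrange(q, 'b n (h d) -> b h n d', h=num_heads)
    if num_kv_groups > 1:
        k, v = (repeat(x, 'b n (h d) -> b (h g) n d', h=num_kv_heads, g=num_kv_groups)
                for x in (k, v))
    else:
        k, v = (rearrange(x, 'b n (h d) -> b h n d', h=num_kv_heads) for x in (k, v))
    if norm_q:
        q = q / (q.sum(-1, True) + 1e-4)
    if norm_k:
        k = k / (k.sum(-1, True) + 1e-4)
    # Naive stand-in for the chunk/fused kernels, which return (o, final_state):
    # a running sum of outer products of k and v gives causal linear attention.
    kv = torch.einsum('bhnk,bhnv->bhnkv', k, v).cumsum(2)
    o = torch.einsum('bhnk,bhnkv->bhnv', q, kv)
    final_state = kv[:, :, -1]
    return o, final_state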
