diff --git a/fla/layers/based.py b/fla/layers/based.py
index bed0c161d..634f37001 100644
--- a/fla/layers/based.py
+++ b/fla/layers/based.py
@@ -15,10 +15,10 @@ class BasedLinearAttention(nn.Module):
+
     def __init__(
         self,
         hidden_size: int,
-        l_max: int = 2048,
         feature_dim: int = 16,
         num_key_value_heads: int = 12,
         num_heads: int = 12,
@@ -28,12 +28,9 @@ def __init__(
         mode: str = "parallel",
     ):
         super().__init__()
-        self.hidden_size
-        self.l_max = l_max
-        self.mode = mode
-        assert self.mode in ["fused_chunk", "parallel", 'chunk']
-        # linear attention
+        self.hidden_size = hidden_size
+        self.mode = mode
         self.feature_name = feature_name
         self.feature_dim = feature_dim
         self.num_key_value_heads = num_key_value_heads
@@ -106,21 +103,3 @@ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor =
         y = self.o_proj(y.to(hidden_states.dtype))
         y = self.dropout(y)
         return y.to(hidden_states.dtype)
-
-
-if __name__ == '__main__':
-    batch = 4
-    seq_len = 1024
-    hidden_size = 1024
-    dtype = torch.float32
-    x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
-    dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
-    model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
-    y = model(x)
-    y.backward(dy, retain_graph=True)
-    x_grad, x.grad = x.grad, None
-    y2 = model.forward_reference(x)
-    y2.backward(dy)
-    assert y.allclose(y2, 0, 1e-4), breakpoint()
-    assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
-    print("Pass")
diff --git a/fla/layers/linear_attn.py b/fla/layers/linear_attn.py
index 8fab71b62..b8ec72f9d 100644
--- a/fla/layers/linear_attn.py
+++ b/fla/layers/linear_attn.py
@@ -31,13 +31,9 @@ def __init__(
         do_feature_map_norm: bool = False,
         elementwise_affine: bool = True,
         norm_eps: float = 1e-5,
-        **kwargs,
+        **kwargs
     ):
         super().__init__()
-        assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
-                               'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."
-
-        assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."
         self.hidden_size = hidden_size
         self.mode = mode
@@ -96,7 +92,7 @@ def elu(x):
             self.feature_map_q = nn.Identity()
             self.feature_map_k = nn.Identity()
         else:
-            raise NotImplementedError
+            raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

         self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
         self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
@@ -107,7 +103,7 @@ def elu(x):
         elif output_norm == 'identity':
             self.norm = nn.Identity()
         else:
-            raise NotImplementedError
+            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

         self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

@@ -127,26 +123,28 @@ def _initialize_weights(self, module: nn.Module):

     def forward(self, x):
         mode = self.mode
-        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
-        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
-        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
         q = self.feature_map_q(q)
         k = self.feature_map_k(k)
-        if self.norm_q:
-            q = q / (q.sum(-1, keepdim=True) + 1e-4)
-        if self.norm_k:
-            k = k / (k.sum(-1, keepdim=True) + 1e-4)
+
+        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
         if self.num_kv_groups > 1:
             k, v = (repeat(x, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
         else:
             k, v = (rearrange(x, 'b n (h d) -> b h n d', h=self.num_kv_heads) for x in (k, v))
+        if self.norm_q:
+            q = q / (q.sum(-1, True) + 1e-4)
+        if self.norm_k:
+            k = k / (k.sum(-1, True) + 1e-4)

         if mode == 'chunk':
-            o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
+            o, final_state = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
         elif mode == 'fused_chunk':
-            o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
+            o, final_state = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
         elif mode == 'fused_recurrent':
-            o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
+            o, final_state = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
         else:
             raise NotImplementedError
         o = self.norm(o)
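Note: the equality/gradient check that this patch drops from the __main__ block of fla/layers/based.py can still be kept as a standalone test. Below is a minimal sketch that simply reuses the removed logic; it assumes a CUDA device is available, and the test name is only a suggestion, not part of the patch.

import torch

from fla.layers.based import BasedLinearAttention


def test_based_chunk_matches_reference():
    # Shapes and tolerances mirror the check removed from based.py.
    batch, seq_len, hidden_size = 4, 1024, 1024
    dtype = torch.float32
    x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device='cuda', requires_grad=True)
    dy = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device='cuda')
    model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()

    # Chunked kernel path.
    y = model(x)
    y.backward(dy, retain_graph=True)
    x_grad, x.grad = x.grad, None

    # Reference (naive) path.
    y2 = model.forward_reference(x)
    y2.backward(dy)

    # Outputs and input gradients should agree within atol=1e-4.
    assert y.allclose(y2, 0, 1e-4)
    assert x_grad.allclose(x.grad, 0, 1e-4)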