From a268cd2e40351ee31c30c5f8a5d1266d35b41829 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Mon, 28 Nov 2022 11:42:13 +0000
Subject: [PATCH] better tinyAtt

---
 RWKV-v4neo/src/model.py | 20 +++++++++++---------
 RWKV-v4neo/train.py     |  1 -
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 0b961d65..38906b47 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -234,10 +234,11 @@ def __init__(self, args, layer_id):
         self.ffn = RWKV_ChannelMix(args, layer_id)
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.tiny_ln = nn.LayerNorm(args.n_embd)
+            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def forward(self, x, x_emb=None):
         args = self.args
@@ -255,11 +256,12 @@ def forward(self, x, x_emb=None):
         x = x + self.ffn(self.ln2(x))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
-            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
-            x = x + c @ self.head_v(x_emb)
+            xx = self.tiny_ln(x)
+            q = self.tiny_q(xx)[:, :T, :]
+            k = self.tiny_k(xx)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+            x = x + c @ self.tiny_v(x_emb)
 
         return x
 
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 580546ef..1adf6972 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -70,7 +70,6 @@
     parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
-    parser.add_argument("--tiny_att_downscale", default=0, type=float)
     parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
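
Below is a minimal standalone sketch of the tiny-attention branch as it reads after this patch: the head_* parameters become tiny_*, a LayerNorm is applied before the Q/K projections, the score scaling becomes the fixed 1/sqrt(tiny_att_dim) in place of the removed --tiny_att_downscale flag, and the causal mask zeroes out future positions (no softmax is used). The TinyAttentionSketch wrapper class and the example sizes are illustrative only; the tensor math mirrors the `+` lines in model.py above.

import torch
import torch.nn as nn

class TinyAttentionSketch(nn.Module):
    """Illustrative standalone version of the tiny-attention branch.

    `n_embd`, `tiny_att_dim`, and `ctx_len` mirror the args used in the patch;
    the class itself is a sketch for explanation, not part of the repo.
    """

    def __init__(self, n_embd, tiny_att_dim, ctx_len):
        super().__init__()
        self.tiny_ln = nn.LayerNorm(n_embd)                        # new in this patch
        self.tiny_q = nn.Linear(n_embd, tiny_att_dim, bias=False)
        self.tiny_k = nn.Linear(n_embd, tiny_att_dim, bias=False)
        self.tiny_v = nn.Linear(n_embd, n_embd, bias=False)
        # causal mask: position t may only attend to positions <= t
        self.register_buffer("tiny_mask", torch.tril(torch.ones(ctx_len, ctx_len)))
        self.tiny_att_dim = tiny_att_dim

    def forward(self, x, x_emb):
        # x: block output (B, T, n_embd); x_emb: token embeddings passed down from the input
        B, T, _ = x.shape
        xx = self.tiny_ln(x)                                        # normalize before Q/K
        q = self.tiny_q(xx)[:, :T, :]
        k = self.tiny_k(xx)[:, :T, :]
        # dot-product scores with the fixed 1/sqrt(tiny_att_dim) scale
        c = (q @ k.transpose(-2, -1)) * (self.tiny_att_dim ** (-0.5))
        # zero out future positions (masked to 0, not -inf; there is no softmax)
        c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
        # values are projected from x_emb, and the result is added as a residual
        return x + c @ self.tiny_v(x_emb)


# Usage with made-up sizes:
# tiny = TinyAttentionSketch(n_embd=768, tiny_att_dim=64, ctx_len=1024)
# x = torch.randn(2, 128, 768); x_emb = torch.randn(2, 128, 768)
# y = tiny(x, x_emb)   # -> (2, 128, 768)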