Skip to content

Commit

Permalink
misc improvement
Browse files (browse the repository at this point in the history)
  • Loading branch information
www committed Aug 25, 2021
1 parent a36fc09 commit 619ed00
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, config, layer_id):
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

- self.time_shift = nn.ZeroPad2d((0,0,1,0))
+ self.time_shift = nn.ZeroPad2d((0,0,1,-1))

self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
Expand All @@ -110,15 +110,15 @@ def forward(self, x):
self.mask = self.mask[:T, :T]
w = w.masked_fill(self.mask == 0, 0)

- x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+ x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
if hasattr(self, 'tiny_att'):
tiny_att = self.tiny_att(x, self.mask)

k = self.key(x)
v = self.value(x)
r = self.receptance(x)

- k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
+ k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)

Expand All @@ -138,7 +138,7 @@ class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
- self.time_shift = nn.ZeroPad2d((0,0,1,0))
+ self.time_shift = nn.ZeroPad2d((0,0,1,-1))

hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
self.key = nn.Linear(config.n_embd, hidden_sz)
Expand All @@ -152,7 +152,7 @@ def __init__(self, config, layer_id):
def forward(self, x):
B, T, C = x.size()

- x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+ x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(self, config, layer_id, time_shift = False):
self.head_size = config.n_attn // config.n_head

if time_shift:
- self.time_shift = nn.ZeroPad2d((0,0,1,0))
+ self.time_shift = nn.ZeroPad2d((0,0,1,-1))

self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
Expand All @@ -252,7 +252,7 @@ def forward(self, x):
B, T, C = x.size()

if hasattr(self, 'time_shift'):
- x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+ x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
Expand Down Expand Up @@ -281,7 +281,7 @@ def __init__(self, config, layer_id, time_shift = False):
self.layer_id = layer_id

if time_shift:
- self.time_shift = nn.ZeroPad2d((0,0,1,0))
+ self.time_shift = nn.ZeroPad2d((0,0,1,-1))

hidden_sz = 3 * config.n_ffn
self.key = nn.Linear(config.n_embd, hidden_sz)
Expand All @@ -291,7 +291,7 @@ def __init__(self, config, layer_id, time_shift = False):
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
- x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+ x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

k = self.key(x)
v = self.value(x)
Expand All @@ -317,7 +317,7 @@ def __init__(self, config, layer_id):
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

- self.time_shift = nn.ZeroPad2d((0,0,1,0))
+ self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
Expand All @@ -338,7 +338,7 @@ def forward(self, x):
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

- x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing
+ x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
Expand Down

0 comments on commit 619ed00

Please sign in to comment.