Skip to content

Commit

Permalink
Merge pull request #9 from Andrei-Aksionov/feature/shape_description_fix
Browse files Browse the repository at this point in the history
Clarify shape descriptions inside forward method
  • Loading branch information
karpathy authored Feb 7, 2023
2 parents d38c865 + 4c8e902 commit 5220142
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,19 @@ def __init__(self, head_size):
self.dropout = nn.Dropout(dropout)

def forward(self, x):
    """Run one head of causal self-attention.

    Input x has shape (batch B, time-step T, channels C); the output has
    shape (B, T, head_size hs).  (Diff-artifact duplicate statements from
    the scraped commit view have been removed; the post-commit lines with
    the corrected `hs` shape comments are kept.)
    """
    B, T, C = x.shape
    k = self.key(x)    # (B, T, hs)
    q = self.query(x)  # (B, T, hs)
    # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
    wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
    # causal mask: positions may only attend to themselves and the past
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
    wei = F.softmax(wei, dim=-1)  # (B, T, T)
    wei = self.dropout(wei)
    # perform the weighted aggregation of the values
    v = self.value(x)  # (B, T, hs)
    out = wei @ v      # (B, T, T) @ (B, T, hs) -> (B, T, hs)
    return out

class MultiHeadAttention(nn.Module):
Expand All @@ -93,7 +95,7 @@ class MultiHeadAttention(nn.Module):
def __init__(self, num_heads, head_size):
    """Multi-head self-attention: num_heads independent Heads of size
    head_size run in parallel, followed by a linear projection back to
    the residual-stream width n_embd.
    """
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    # The concatenated head outputs have width num_heads * head_size, so the
    # projection maps that back to n_embd.  (The stale duplicate assignment
    # `nn.Linear(n_embd, n_embd)` left over from the diff view is removed;
    # only the corrected post-commit line is kept.)
    self.proj = nn.Linear(head_size * num_heads, n_embd)
    self.dropout = nn.Dropout(dropout)

def forward(self, x):
Expand Down

0 comments on commit 5220142

Please sign in to comment.