prepare for pg19 training and inference
prepare for pg19 training / inference, to be completed next week
lucidrains committed May 5, 2024
1 parent d1b053c commit ca05f50
Showing 2 changed files with 24 additions and 11 deletions.
infini_transformer_pytorch/infini_transformer.py (33 changes: 23 additions & 10 deletions)
@@ -1,4 +1,5 @@
-from typing import Tuple, List, Optional
+from __future__ import annotations
+from typing import Tuple, List
 
 import torch
 from torch import nn, Tensor
@@ -87,7 +88,8 @@ def __init__(
     def forward(
         self,
         x,
-        past_memories: Optional[Memories] = None,
+        past_memories: Memories | None = None,
+        return_new_memories = False,
         eps = 1e-10
     ) -> Tuple[Tensor, Memories]:
         """
@@ -164,6 +166,17 @@ def retrieve_from_kv_memories(t, past_memories: Memories):
 
         out = out * gates + mem_out * (1. - gates) # eq (6) - figure 3 shows how heads emergently specialize to look either at the present, past, or a bit of both
 
+        # merge heads and combine
+
+        out = self.merge_heads(out)
+        out = self.to_out(out)
+
+        # if new memories are not needed, early return
+        # at inference time, kv cache up to segment length and then compress memories into kv
+
+        if not return_new_memories:
+            return out, past_memories
+
         # create the next memories
 
         if exists(past_memories) and self.use_mem_delta_rule:
@@ -183,11 +196,6 @@ def retrieve_from_kv_memories(t, past_memories: Memories):
 
         new_memories = (new_memories_kv, new_memories_norm)
 
-        # merge heads and combine
-
-        out = self.merge_heads(out)
-        out = self.to_out(out)
-
         return out, new_memories
 
 # main class
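For context on the new_memories = (new_memories_kv, new_memories_norm) pair finalized in the hunk just above: it holds the Infini-attention compressive memory, an accumulation of kernelized key/value outer products plus a running normalizer across segments. Below is a rough standalone sketch of that update in plain PyTorch, assuming the paper's ELU + 1 kernel and a hypothetical helper name; it is an illustration, not the repository's implementation.

import torch
import torch.nn.functional as F

def update_memories(k, v, past_kv = None, past_norm = None):
    # k, v: (batch, heads, seq, dim_head)
    sigma_k = F.elu(k) + 1  # the paper's kernel choice, keeps entries positive

    # accumulate sigma(K)^T V and the normalizer sum of sigma(K) over the segment
    new_kv = torch.einsum('b h n d, b h n e -> b h d e', sigma_k, v)
    new_norm = sigma_k.sum(dim = -2)

    if past_kv is not None:  # carry over memories from previous segments
        new_kv = new_kv + past_kv
        new_norm = new_norm + past_norm

    return new_kv, new_norm

The use_mem_delta_rule branch above corresponds to the paper's delta rule variant, which first subtracts what the memory would already retrieve for the current keys before accumulating, so repeated content is not counted twice.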
@@ -234,7 +242,7 @@ def __init__(
     def forward(
         self,
         x,
-        past_memories: Optional[List[Memories]] = None,
+        past_memories: List[Memories] | None = None,
         return_memories = False,
         detach_memories = False
     ):
@@ -245,7 +253,12 @@ def forward(
 
         for attn, ff in self.layers:
             past_memories = next(past_memories_iter, None)
-            attn_out, layer_new_memories = attn(x, past_memories = past_memories)
+
+            attn_out, layer_new_memories = attn(
+                x,
+                past_memories = past_memories,
+                return_new_memories = return_memories
+            )
 
             x = attn_out + x
             x = ff(x) + x
@@ -257,7 +270,7 @@ def forward(
         logits = self.to_logits(embed)
 
         if not return_memories:
-            return logits
+            return logits, past_memories
 
         if detach_memories:
             detach_memories_(new_memories)
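Taken together, the changes above let each attention layer project its output and return early, passing past_memories through untouched when new memories are not requested, while the transformer threads return_new_memories down from its own return_memories flag and now returns a (logits, memories) pair in both branches. What follows is a minimal sketch of how the updated interface might be driven segment by segment; the constructor hyperparameters and segment length are illustrative assumptions, and the call signature is read off this diff rather than documented usage.

import torch
from infini_transformer_pytorch import InfiniTransformer

model = InfiniTransformer(
    num_tokens = 256,   # assumed hyperparameters, for illustration only
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8
)

seq = torch.randint(0, 256, (1, 3072))      # a long sequence of token ids

past_memories = None

for segment in seq.split(1024, dim = -1):   # process one fixed-size segment at a time
    logits, past_memories = model(
        segment,
        past_memories = past_memories,
        return_memories = True,    # compress this segment's kv into memories for the next
        detach_memories = True     # keep gradients from crossing segment boundaries
    )

Note that the early-return path added in this commit means a caller with return_memories = False now gets back (logits, past_memories) instead of bare logits, so two values should be unpacked either way.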
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
     name = 'infini-transformer-pytorch',
     packages = find_packages(exclude = []),
-    version = '0.0.9',
+    version = '0.0.10',
     license='MIT',
     description = 'Infini-Transformer in Pytorch',
     author = 'Phil Wang',
