Skip to content

saran-gangster/IKT

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

IKT

100x parameter efficiency + long context (clickbait)

# Minimal usage example: build an InfiniKANTransformer and run a single
# forward pass on random token ids, then inspect the returned memories.
import torch  # required for torch.randint below

from IKT import InfiniKANTransformer

# Model parameters
num_tokens = 10000  # upper bound (exclusive) of the token ids sampled below
dim = 512
depth = 12  # the example output below reports one memory entry per layer
heads = 8
dim_head = 64
attn_dropout = 0.1
ff_dropout = 0.1
use_mem_delta_rule = True
kan_implementation = "EFFICIENT_KAN"

# Initialize the model
model = InfiniKANTransformer(
    num_tokens=num_tokens,
    dim=dim,
    depth=depth,
    heads=heads,
    dim_head=dim_head,
    attn_dropout=attn_dropout,
    ff_dropout=ff_dropout,
    use_mem_delta_rule=use_mem_delta_rule,
    kan_implementation=kan_implementation,
)

# One sequence of 512 random token ids drawn from [0, num_tokens)
x = torch.randint(0, num_tokens, (1, 512))

# Forward pass with memory retrieval
logits, memories = model(x, return_new_memories=True)  # Set False if you don't want memory

# Output shapes
print(f"Logits shape: {logits.shape}")  # (1, 512, 10000)
print(f"Number of layers with memories: {len(memories)}")  # 12 (number of layers)
print(f"Memory shape (KV matrix): {memories[0][0].shape}")  # (1, 8, 64, 64) (batch, heads, dim_head, dim_head)
print(f"Memory shape (Normalization vector): {memories[0][1].shape}")  # (1, 8, 64) (batch, heads, dim_head)

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages