lucidrains/CALM-pytorch

CALM - Pytorch

Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google Deepmind

Can support any number of augmentation LLMs

Install

$ pip install CALM-pytorch

Usage

ex. with x-transformers

import torch
from x_transformers import TransformerWrapper, Decoder

augment_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

calm = CALM(
    anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)

# mock input

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))

# forward for finetuning loss

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

# after much training, prompt the composed model

generated = calm.generate(
    prompt = seq[:, :1],
    seq_len = 1024
)

To use a handy trainer class using 🤗 Accelerate, just import FineTuner and use as follows

from CALM_pytorch import FineTuner

trainer = FineTuner(
    calm = calm,
    dataset = dataset,   # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a tuple, in which case data_kwargs must be set to the kwarg names in the matching order
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)

trainer()

# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps
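
The trainer snippet above leaves `dataset` undefined. A minimal sketch of a mock dataset that returns the dictionary form described in the comment follows. The class name `MockCALMDataset` and its constructor arguments are hypothetical, and it assumes the trainer batches samples itself, so each item carries no batch dimension.

```python
import torch
from torch.utils.data import Dataset

class MockCALMDataset(Dataset):
    """Hypothetical stand-in dataset yielding random token ids."""

    def __init__(self, num_samples = 64, num_tokens = 20000, seq_len = 1024, prompt_len = 256):
        self.num_samples = num_samples
        self.num_tokens = num_tokens
        self.seq_len = seq_len
        self.prompt_len = prompt_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # assumption: samples are unbatched, batching is left to the trainer
        seq = torch.randint(0, self.num_tokens, (self.seq_len,))
        mask = torch.ones(self.seq_len, dtype = torch.bool)
        prompt = torch.randint(0, self.num_tokens, (self.prompt_len,))

        # keys match the kwarg names passed to calm(...) in the usage example
        return dict(seq = seq, mask = mask, prompt = prompt)

dataset = MockCALMDataset()
```

In practice the random tensors would be replaced by tokenized training data, with the same three keys.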

To explore multiple augmentation LLMs, simply pass in a list for augment_llms

ex.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams wrapping model and other hparams specific to that transformer
)

To explore different types of connectivity between the anchor and augmentation model(s), pass in the connections as a tuple of integer pairs. Each pair is (augment layer, anchor layer): the hidden states of that augment layer are attended to by that anchor layer.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            connections = (
                (1, 12),  # 1st layer of augment llm1 attended to by 12th layer of anchor llm
                (2, 12),
                (3, 12),
                (4, 12),
            ),
        ),
        AugmentParams(
            model = augment_llm2,
            connections = (
                (6, 1),   # 6th layer of augment llm2 attended to by 1st layer of anchor llm
                (6, 2),
                (12, 12),
            )
        )
    )
)
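
Writing connection pairs by hand gets tedious for deeper models. The sketch below builds and sanity-checks them in plain Python. Both helpers (`fan_in_connections`, `validate_connections`) are hypothetical and not part of CALM_pytorch. They assume 1-indexed layers, as in the comments above, and that the caller supplies the number of layers each model actually exposes.

```python
def fan_in_connections(augment_layers, anchor_layer):
    """Pair each listed augment layer with one anchor layer.

    Returns (augment layer, anchor layer) pairs, the order used above.
    """
    return tuple((augment_layer, anchor_layer) for augment_layer in augment_layers)

def validate_connections(connections, augment_depth, anchor_depth):
    """Raise ValueError if any pair points outside the given 1-indexed depths."""
    for augment_layer, anchor_layer in connections:
        if not 1 <= augment_layer <= augment_depth:
            raise ValueError(f'augment layer {augment_layer} out of range 1..{augment_depth}')
        if not 1 <= anchor_layer <= anchor_depth:
            raise ValueError(f'anchor layer {anchor_layer} out of range 1..{anchor_depth}')
    return tuple(connections)

# reproduces the connections given for augment_llm1 above
connections = fan_in_connections(range(1, 5), anchor_layer = 12)
validate_connections(connections, augment_depth = 12, anchor_depth = 12)

print(connections)  # ((1, 12), (2, 12), (3, 12), (4, 12))
```

The resulting tuple can be passed as `connections` to `AugmentParams`.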

CALM setup with 2 specialized augmentation LLMs + a vision transformer

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm1 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm2 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = augment_llm2,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = vit,
            input_shape = (3, 256, 256),
            hidden_position = 'input',
            extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
        )
    ),
    attn_kwargs = dict(
        linear_project_context = True,
        pre_rmsnorm = True,
        flash = True
    )
)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompt = (
    torch.randint(0, 20000, (1, 256)),
    torch.randint(0, 20000, (1, 256)),
    torch.randn(1, 3, 256, 256)
)

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

Todo

  • figure out how to correctly mask augment llm tokens

  • auto-derive model dimensions with dummy input

  • take care of finetuning training logic

  • show example of manual definitions of custom connectivity between 2+ attention networks

  • if anchor and augment transformer block modules are directly passed in (without extraction fn), run a dummy input through both networks and order them correctly using hooks

  • fix example for x-transformers, as in x-transformers, depth is actually depth x 2, taking hiddens from after attention and ff

  • when finely specifying hidden positions, make sure to reorder it if the transformer blocks themselves were passed in and not ordered to begin with

  • extend to a list of augmentation llms

    • full connectivity customization
    • custom number of augmentation layers per augmentation llm
    • make simple vit work
      • refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
      • show example
  • take care of caching the augment hiddens when sampling. forget about anchor kv cache for now

    • logic for not releasing the saved output from recorder, for inference
    • managing cross attention block state for popping the saved output from the recorder
    • move the augmentation forwards into one shared method, and craft out sampling method for anchor
  • able to wire up with just module names

  • show an example with giving the LLM ability to hear as well, using hubert or wav2vec wrappers

  • handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM

  • add an option for self attention path way with memory tokens attending to hidden states of all augmentation llms, akin to what was done with Zorro

Citations

@inproceedings{Bansal2024LLMAL,
  title   = {LLM Augmented LLMs: Expanding Capabilities through Composition},
  author  = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
  year    = {2024},
  url     = {https://api.semanticscholar.org/CorpusID:266755751}
}
