
Loading checkpoint before fabric.setup(model) gets abnormal loss when using fabric.init_module() #1868

Open
kobenaxie opened this issue Dec 10, 2024 · 1 comment
Labels: question (Further information is requested)

kobenaxie commented Dec 10, 2024

from pathlib import Path

import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

from litgpt.args import TrainArgs
from litgpt.config import Config
from litgpt.model import GPT, Block
from litgpt.data import Alpaca2k
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
    chunked_cross_entropy,
    load_checkpoint,
    num_parameters,
    get_default_supported_precision,
)


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
    # linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def main(
    checkpoint_dir: Path,
    devices: int = 8,
    num_nodes: int = 1,
    precision: str = "bf16-true",
    seed: int = 1337,
) -> None:
    torch.set_float32_matmul_precision("high")

    train_args = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=64,
        micro_batch_size=4,
        lr_warmup_steps=1000,
        epochs=10,
        max_steps=10000,
    )

    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )
    
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )
    fabric.launch()
    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
    
    dataset = Alpaca2k()
    tokenizer = Tokenizer(str(checkpoint_dir))
    dataset.connect(tokenizer, batch_size=train_args.micro_batch_size, max_seq_length=512)
    with fabric.rank_zero_first():
        dataset.prepare_data()
    dataset.setup()
    dataloader = dataset.train_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)

    checkpoint_path = str(checkpoint_dir / "lit_model.pth")
    config = Config.from_file(checkpoint_dir / "model_config.yaml")
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    # load_checkpoint(fabric, model, checkpoint_path)
    model = fabric.setup(model)
    load_checkpoint(fabric, model, checkpoint_path)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train_args.lr_warmup_steps, max_steps=train_args.max_steps)

    model.train()
    for epoch in range(train_args.epochs):
        for step, batch in enumerate(dataloader, 1):
            input, target = batch["input_ids"], batch["labels"]
            logits = model(input)
            loss = chunked_cross_entropy(logits[..., :-1, :], target[..., 1:])
            fabric.backward(loss)

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            fabric.print(f"{step = } | loss train: {loss.detach().item()}")


if __name__ == "__main__":
    checkpoint_dir = Path("./Qwen2.5-1.5B/")

    main(checkpoint_dir)

When I initialize the model with fabric.init_module(True) and load the checkpoint after model = fabric.setup(model), the training loss is normal:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
model = fabric.setup(model)
load_checkpoint(fabric, model, checkpoint_path)

step = 1 | loss train: 0.8448048233985901
step = 2 | loss train: 1.3229767084121704
step = 3 | loss train: 1.2647839784622192
step = 4 | loss train: 1.287076711654663
step = 5 | loss train: 1.0357563495635986

but when I load the checkpoint before model = fabric.setup(model), the loss is abnormal:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 12.027938842773438
step = 2 | loss train: 12.051375389099121
step = 3 | loss train: 12.112957954406738
step = 4 | loss train: 12.08558177947998
step = 5 | loss train: 12.089488983154297

Another observation: if I do not use fabric.init_module() at all, I get a normal loss even when loading the checkpoint before fabric.setup(model):

# with fabric.init_module(empty_init=(fabric.world_size > 1)):
if True:
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 0.8447667956352234
step = 2 | loss train: 1.3229438066482544
step = 3 | loss train: 1.2663335800170898
step = 4 | loss train: 1.2902932167053223
step = 5 | loss train: 1.035811185836792

So what is the correct way to load HF models converted by litgpt.scripts.convert_hf_checkpoint?

Andrei-Aksionov (Collaborator) commented Dec 11, 2024

Hello @kobenaxie

This approach addresses the issue of high memory consumption when loading weights into a model.

In the traditional approach, two sets of weights exist simultaneously:

  1. Randomly initialized weights when the model is created.
  2. Pretrained weights that need to be loaded into the model.

model = GPT(...)           # model with random weights
weights = torch.load(...)  # pretrained weights
model.load_state_dict(weights)

To mitigate this, the process is split into multiple steps:

  1. Model creation on a meta device:
    Using fabric.init_module, the model is created on a meta device. On this device, memory usage is minimal because the weight matrices remain "empty" until explicitly materialized. (Refer to the meta device documentation).

  2. Target device setup:
    The fabric.setup(model) call specifies the target device (e.g., GPU) where the model will be placed.

  3. Loading pretrained weights:
    Finally, load_checkpoint(fabric, model, checkpoint_path) loads the pretrained weights into the model, materializing it on the target device with minimal memory overhead.
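
Putting it together, this is the ordering that works (the same one as in your first snippet): create the model inside fabric.init_module, call fabric.setup, and only then load the checkpoint:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)  # with empty_init=True, parameters are created on the meta device
model = fabric.setup(model)  # wrap with FSDP and materialize on the target device
load_checkpoint(fabric, model, checkpoint_path)  # now fill in the pretrained weights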


"but when I load the checkpoint before model = fabric.setup(model), the loss is abnormal"

This happens because the model ends up materialized with random weights: load_checkpoint was called while the model was still on the meta device, before fabric.setup.
The load_checkpoint function uses lazy_load from PyTorch, which cannot do the materialization.

So when you run fabric.init_module (placing the model on the meta device) and then load_checkpoint, nothing really happens: the model stays on the meta device. When it is later materialized on the target device, the weight values are completely random.

When you commented out fabric.init_module, the model was created on the CPU with random weights, load_checkpoint then loaded the pretrained weights into it, and fabric.setup moved the model to the target device.
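
If you want to see this behavior in isolation, here is a minimal plain-PyTorch sketch (not Fabric's internal code path) showing that a meta-device module has no real storage, and that materializing it later yields arbitrary values rather than anything you "loaded" earlier:

import torch
import torch.nn as nn

# Modules created on the meta device have parameters without real storage.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

print(layer.weight.device)  # meta -- there is no data to load weights into yet

# Materializing later (conceptually what fabric.setup does for a meta model)
# allocates uninitialized memory, not the values from any earlier load attempt.
layer = layer.to_empty(device="cpu")
print(layer.weight)  # arbitrary, uninitialized values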

The loss value provides a hint.
With a vocabulary size of approximately 151k (for Qwen2.5-1.5B) and randomly initialized weights, the expected loss is around 12.

import torch
import torch.nn.functional as F

batch_size = 32
vocab_size = 151_643
logits = torch.randn(batch_size, vocab_size)
targets = torch.randint(0, 2, (batch_size, vocab_size)).float()

loss_ce = F.cross_entropy(logits, targets.argmax(dim=1))
print(loss_ce)

>> tensor(12.1440)
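
As a sanity check on that number: a model that is effectively guessing uniformly over the vocabulary has a cross-entropy of roughly ln(vocab_size), which for Qwen2.5's ~151k vocabulary lands right around the values above:

import math

print(math.log(151_643))  # ~11.93, close to the ~12 observed with random weights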
