This code addresses the issue of high memory consumption when loading weights into a model.
In the traditional approach, two sets of weights exist in memory simultaneously:

- the randomly initialized weights created with the model, and
- the pretrained weights that need to be loaded into it.
```python
model = GPT(...)                 # model with random weights
weights = torch.load(...)        # pretrained weights
model.load_state_dict(weights)
```
To mitigate this, the process is split into multiple steps:
1. **Model creation on a meta device.** Using `fabric.init_module`, the model is created on the meta device. There, memory usage is minimal because the weight matrices stay "empty" until they are explicitly materialized (see the meta-device documentation).
2. **Target device setup.** The `fabric.setup(model)` call specifies the target device (e.g., a GPU) where the model will be placed.
3. **Loading pretrained weights.** Finally, `load_checkpoint(fabric, model, checkpoint_path)` loads the pretrained weights into the model, materializing it on the target device with minimal memory overhead. The full flow is sketched right after this list.
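Putting the three steps together, the flow looks roughly like the sketch below. This is a minimal example, not taken verbatim from the thread: the `litgpt` imports, the config name, and the checkpoint path are assumptions and may need adjusting to your setup.

```python
from pathlib import Path

import lightning as L
from litgpt.model import GPT, Config          # assumed litgpt imports
from litgpt.utils import load_checkpoint      # assumed location of the helper

fabric = L.Fabric(accelerator="auto", devices=1)
fabric.launch()

config = Config.from_name("Qwen2.5-1.5B")     # example model from the thread
checkpoint_path = Path("checkpoints/Qwen/Qwen2.5-1.5B/lit_model.pth")  # placeholder path

# 1) Create the model on the meta device: weights stay "empty", almost no memory is used.
with fabric.init_module(empty_init=True):
    model = GPT(config)

# 2) Let Fabric place/materialize the model on the target device.
model = fabric.setup(model)

# 3) Load the pretrained weights into the materialized model.
load_checkpoint(fabric, model, checkpoint_path)
```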
> but when loading the checkpoint before `model = fabric.setup(model)`, I get an abnormal loss
This happens because the model gets materialized with random weights: `load_checkpoint` was called before `fabric.setup` while the model was still on the meta device. The `load_checkpoint` function relies on PyTorch's lazy loading, which cannot materialize a meta-device model.

So when you run `fabric.init_module` (placing the model on the meta device) and then call `load_checkpoint`, nothing really happens; the model stays on the meta device. When it is later materialized on the target device, the weight values are completely random.

When you commented out `fabric.init_module`, the model was created on the CPU with random weights, `load_checkpoint` then loaded the pretrained weights into it, and `fabric.setup` moved the model to the target device.
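For contrast, this is roughly the problematic ordering described above (reusing the names and setup from the sketch earlier; the comments mark where things go wrong):

```python
with fabric.init_module(empty_init=True):
    model = GPT(config)                  # model lives on the meta device

# Lazy-loaded weights cannot be copied into meta tensors, so this is effectively a no-op.
load_checkpoint(fabric, model, checkpoint_path)

# Materialization happens here, so the model ends up with freshly initialized (random) weights.
model = fabric.setup(model)
```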
The loss value provides a hint.
With a vocabulary size of approximately 151k (for Qwen2.5-1.5B) and randomly initialized weights, the expected loss is around 12.
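As a rough sanity check (a back-of-the-envelope estimate, not a figure from the thread): with untrained weights the predictions are essentially uniform over the vocabulary, so the expected cross-entropy loss is about ln(vocab_size):

```python
import math

# Cross-entropy of a uniform distribution over the vocabulary:
# -log(1 / vocab_size) = log(vocab_size)
vocab_size = 151_000  # approximate vocabulary size of Qwen2.5-1.5B
print(math.log(vocab_size))  # ~11.9, consistent with the observed loss of around 12
```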
Init the model with `fabric.init_module(True)` and load the checkpoint after `model = fabric.setup(model)`, and the training loss is normal. But when loading the checkpoint before `model = fabric.setup(model)`, I get an abnormal loss.

Another phenomenon is that, if I do not use `fabric.init_module()`, I get a normal loss when loading the checkpoint before `fabric.setup(model)`.

So how do I correctly load HF models converted by `litgpt.scripts.convert_hf_checkpoint`?