
Problems with saving standalone gemma-2b-it after fine-tuning with LoRA on TPU v3-8 #29659

Closed
PawKanarek opened this issue Mar 14, 2024 · 39 comments

@PawKanarek

System Info

- `transformers` version: 4.39.0.dev0
- Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config:    - compute_environment: LOCAL_MACHINE
        - distributed_type: TPU
        - mixed_precision: no
        - use_cpu: False
        - debug: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
- PyTorch version (GPU?): 2.3.0.dev20240307 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no, this is TPU
- Using distributed or parallel set-up in script?: yes
print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Python 3.10.13

Who can help?

@ArthurZucker , @younesbelkada, @muellerzr, @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello, I have a problem with training the gemma-2b-it model on a Google TPU v3-8. My goal is to fine-tune it with a PEFT LoRA adapter and then save it as a standalone model.

For merging the base model with the LoRA adapter I followed this guide: https://huggingface.co/docs/trl/main/en/use_model
The training code is based on this blog post: https://huggingface.co/blog/gemma-peft

The problem is that training takes a while (for 300k rows in a data loader it might take up to 8 hours), but after training the model seems… untrained. The inference output looks almost identical to the output of the base model.

Furthermore, when I check the weights of the trained and original models, they appear to be identical.

I also consistently encounter the following error message while loading the saved model:

Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 
(...)
'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Below is the minimal working code that trains and saves the model.

# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")


device = xm.xla_device() # Set up TPU device.


def check_model_weights_equality(model1, model2):
    params1, params2 = model1.parameters(), model2.parameters() 
    sum1 = sum(p.numel() for p in params1)
    sum2 = sum(p.numel() for p in params2)
    
    if (sum1 != sum2):
        print(f"Number of parameters are different in {model1.__class__}:{sum1} and {model2.__class__}:{sum2} are different")
        return False
    
    for p1, p2 in zip(params1, params2):
        if not torch.equal(p1, p2):
            print(f"weights of {model1.__class__} and {model2.__class__} are different")
            return False
    
    print(f"models {model1.__class__} and {model2.__class__} are the same")
    return True

def train():
    tokenizer =  AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4, # small epochs for brevity, but the same is also with larger epochs
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    trainer.train()
    merged_model = trainer.model.merge_and_unload() # merge LORA with base model
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("output/merged", torch_dtype=torch.bfloat16)
    original_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    check_model_weights_equality(trained_model, original_model)

if __name__ == "__main__":
    train()

And this is the output

 cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 54351 -- /home/raix/minefinetune/server/train.py 
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.62it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710424188.529494  727042 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710424188.529574  727042 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710424188.529582  727042 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:300: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
  0%|                                                                                               | 0/28 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'train_runtime': 79.372, 'train_samples_per_second': 22.577, 'train_steps_per_second': 0.353, 'train_loss': 5.650647844587054, 'epoch': 4.0}
100%|██████████████████████████████████████████████████████████████████████████████████████| 28/28 [01:19<00:00,  2.83s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.10it/s]
Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 
'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.60it/s]
models <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> are the same

I'm stuck, so I'm asking for help. I tried many combinations of PeftModel.merge_and_unload(), save_pretrained(), and trainer.save_model(), and nothing seems to work. Any idea to push this issue forward will be appreciated. Thanks.
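For reference, the standalone-save flow described in the TRL guide linked above looks roughly like the sketch below (a minimal sketch; the output/trained_model and output/merged paths are placeholders mirroring the script above, and as the paragraph above notes, variants of this flow still produced an unchanged model in this setup):

# Sketch of the "save adapter, reload, merge, save standalone" flow from the
# TRL docs; paths are placeholders matching the reproduction script.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
# Load the LoRA adapter that trainer.save_model() wrote to disk.
peft_model = PeftModel.from_pretrained(base, "output/trained_model")
# Fold the adapter weights into the base weights and drop the PEFT wrappers.
standalone = peft_model.merge_and_unload()
standalone.save_pretrained("output/merged")  # standalone checkpoint, no adapter required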

Expected behavior

Training trains the model.

@PawKanarek
Author

PawKanarek commented Mar 14, 2024

I modified the code a little bit to add some sanity checks.

def train():
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it") # sanity check model
    
    tokenizer =  AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4,
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    # 1
    trainer.train()
    print("comparing gemma2it with trainer.model")
    compare_weights(gemma2it, trainer.model) # different GemmaForCausalLM:2506172416 params vs SpmdFullyShardedDataParallel:3031123968 params
    
    # 2
    merged_model = trainer.model.merge_and_unload()
    print("comparing gemma2it with merged_model")
    compare_weights(gemma2it, merged_model) # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params
    
    # 3
    print("saving merged_model")
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged_model")
    compare_weights(gemma2it, merged_model) # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params

    # 4
    print("comparing loaded merged_model from disk with in-memory merged_model")
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    compare_weights(merged_model, loaded_merged_model) # different GemmaForCausalLM:3030460416 params vs GemmaForCausalLM:2506172416 params

    # 5
    print("comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model) # models  GemmaForCausalLM and GemmaForCausalLM are the same

I added some sanity checks against the base, untouched gemma2it model, plus some mid-step comparisons:

  1. Check whether the model after training, trainer.model, differs from the base gemma2it: yes, they differ in number of parameters, which implies training was successful.
  2. Check whether the trained model after merging, merged_model, differs from the base gemma2it: yes, they differ in number of parameters, which implies merging was successful.
  3. Save the merged model and check whether merged_model still differs from the base gemma2it: yes, they still differ in number of parameters, which implies saving does nothing to the in-memory parameters.
  4. Load the merged model from disk as loaded_merged_model and check whether it differs from the in-memory merged_model before saving: YES, THEY ARE DIFFERENT :( which implies that something is wrong with loading (or saving) the model.
    4.1. This warning popped up when loading the model from disk:
Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged_model and are newly initialized: ['model.layers.0.input_layernorm.weight', (...) 'model.layers.9.self_attn.v_proj.weight']
(...)You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5. Check whether the merged model loaded from disk, loaded_merged_model, differs from the base gemma2it: no, they are the same... which implies all my training was worthless...

Looks like there is something fishy with my code when saving/loading the model from disk... I'll update if I notice what's wrong. I will also check why my weights are saved under keys containing _orig_module.
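A minimal sketch of how that inspection might look, assuming the merged checkpoint was saved to output/merged_model in safetensors format (path and shard filenames are assumptions):

# List the parameter names stored in the saved checkpoint and flag those that
# still carry the "_orig_module" wrapper prefix.
import glob
from safetensors import safe_open

keys = []
for shard in sorted(glob.glob("output/merged_model/*.safetensors")):
    with safe_open(shard, framework="pt", device="cpu") as f:
        keys.extend(f.keys())

wrapped = [k for k in keys if "._orig_module." in k]
print(f"{len(wrapped)}/{len(keys)} keys contain '_orig_module'")
print(wrapped[:5])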

@PawKanarek PawKanarek changed the title Fine-tuning gemma-2b-it on TPU v3-8 seems not working Problems with saving standalone gemma-2b-it after fine-tuning with LoRA on TPU v3-8 Mar 14, 2024
@zorrofox

Hi @PawKanarek

Please see #29388. By the way, have you tested LoRA fine-tuning quality on TPU XLA? I have explored LoRA a bit, but it had no effect on the base model and the generated output was essentially the same as the base model's.

@PawKanarek
Author

Hi @zorrofox, and thanks for the insight! It looks like my transformers fork didn't include the change from that PR.
What kind of fine-tuning performance are you talking about? Do you want to know how long it takes to train the model with LoRA, or how well the model behaves after fine-tuning?

@PawKanarek
Author

I used the trainer.save_pretrained function mentioned in PR #29388, but it didn't change anything: the trained model after saving is still exactly the same as before training.

@PawKanarek
Author

I think I fixed it, but I wouldn't recommend this fix to anyone, so I'm not even thinking about making a PR.

It's a patch rather than a fix, but I think it works. To check whether it really works, I will train gemma-2b-it until it overfits on the training dataset and then take a look at the inference output.

To apply my patch you would have to add a new parameter to save_pretrained:
https://github.com/huggingface/transformers/blob/f02aea27378dd57c2ced4b28ff9e58ec3876340a/src/transformers/modeling_utils.py#L2190C1-L2203C7

formatting_weights_func = None,

Also add this code before sharding https://github.com/huggingface/transformers/blob/03847ef45189d328a51f428b0a61a6b891e69f88/src/transformers/modeling_utils.py#L2429C1-L2437C111

# apply formatting to the weights before saving 
if formatting_weights_func is not None: 
    for old_key in list(state_dict.keys()):
        new_key = formatting_weights_func(old_key)
        logger.debug(f"changed {old_key=} to {new_key=}")
        state_dict[new_key] = state_dict.pop(old_key)

With these changes I can finally spot a difference between the trained model loaded from disk and the base model it was fine-tuned from, and this warning is also gone:

Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']

Here is the updated script:
def compare_weights(model1, model2):
    name1, name2 = model1.__class__.__name__, model2.__class__.__name__
    params1, params2 = model1.parameters(), model2.parameters() 
    sum1, sum2 = sum(p.numel() for p in params1), sum(p.numel() for p in params2)
    
    if (sum1 != sum2):
        print(f"!!! different in {name1}:{sum1} params vs {name2}:{sum2} params")
    
    for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
        if n1 != n2:
            print(f"!!! Parameter names differ: {n1} != {n2}")
            return False
        if not torch.equal(p1.data, p2.data):
            print(f"!!! Parameter values differ: {n1}, {p1.data}, {p2.data}")
            return False
        
def formmating_func(old_key):
    return old_key.replace('._orig_module', '')

def train():
    # the same training config as before
    trainer.train()
    trainer_model = trainer.model.to('cpu')
    merged_model = trainer_model.merge_and_unload()
    merged_model.save_pretrained("output/merged_model", formatting_weights_func = formmating_func)
    
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
    print("!!! comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model) # !!! FINALLY !!! Parameter values differ: model.layers.0.self_attn.k_proj.weight, tensor([[-3.2043e-04,  8.1177e-03,  3.0365e-03,  ..., -5.3101e-03,

I'm not closing this issue because I didn't fix it; the true issue is still hidden somewhere. This is only a workaround.
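For completeness, a less invasive variant of the same workaround, as a minimal sketch under the assumption that the stray "._orig_module" prefix is the only problem: rename the keys of the merged state dict and pass it to save_pretrained via its state_dict argument, instead of patching modeling_utils.py. The strip_orig_module helper below is hypothetical, not part of any library:

# Drop the "._orig_module" wrapper prefix from every key, then save the model
# with the cleaned state dict rather than modifying save_pretrained itself.
def strip_orig_module(state_dict):
    return {key.replace("._orig_module", ""): value for key, value in state_dict.items()}

merged_model = trainer.model.to("cpu").merge_and_unload()
cleaned_state_dict = strip_orig_module(merged_model.state_dict())
merged_model.save_pretrained("output/merged_model", state_dict=cleaned_state_dict)

Whether save_pretrained treats a passed state_dict identically under XLA FSDP has not been verified here.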

@zorrofox

@PawKanarek Thanks a lot for your advice; I have the same issue. I think you have found the root cause of why the trained model does not change.

@amyeroberts
Collaborator

cc @pacman100 @muellerzr @shub-kris

@shub-kris
Contributor

@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of TPU?

@shub-kris
Contributor

shub-kris commented Mar 18, 2024

@PawKanarek can you please also provide the training logs, and run with logging_steps=1?
Also use save_strategy=epoch

@shub-kris
Contributor

@PawKanarek also after training can you try saving with trainer.save_model('output_dir')

@shub-kris
Contributor

@PawKanarek also with your patch did it work?

@shub-kris
Contributor

@PawKanarek one last thing that I would like to see: does the generation differ when using model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16) for generation on a GPU?

@PawKanarek
Author

PawKanarek commented Mar 18, 2024

@shub-kris thanks,

@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of TPU?

I don't have a GPU capable of training the gemma-2b-it model. I only have my local MacBook with the MPS backend and Google Cloud TPUs (thanks to https://sites.research.google/trc/about/).

@PawKanarek can you also provide the training logs please and run with logging_steps=1?
Also use save_strategy=epoch

I will try to give you logs tomorrow. Today the machine is busy with training :)

@PawKanarek also after training can you try saving with trainer.save_model('output_dir')

I tried it many times, no success.

@PawKanarek also with your #29659 (comment) did it work?

Yes. It works.

@PawKanarek one last thing that I would like to see: does the generation differ when using model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16) for generation on a GPU?

Sadly, I don't have an NVIDIA GPU.
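If it helps, the comparison could in principle also be run on CPU (or MPS) instead of a GPU; a minimal sketch, assuming the adapter was saved to output/trained_model (prompt and path are placeholders):

# Load the base model plus the LoRA adapter via peft's AutoPeftModelForCausalLM
# on CPU and generate, to compare the output against the plain base model.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoPeftModelForCausalLM.from_pretrained("output/trained_model", torch_dtype=torch.bfloat16)

prompt = "..."  # placeholder: use the same prompt for the base and the fine-tuned model
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))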

@shub-kris
Contributor

@PawKanarek thanks for your answers. I am having a look and will post here once I get to the root of the issue.

@shub-kris
Contributor

shub-kris commented Mar 18, 2024

I tried the following script on a GPU

import torch
import peft
import trl
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

def check_model_weights_equality(model1, model2):
    params1, params2 = model1.parameters(), model2.parameters() 
    sum1 = sum(p.numel() for p in params1)
    sum2 = sum(p.numel() for p in params2)
    
    if (sum1 != sum2):
        print(f"Number of parameters are different in {model1.__class__}:{sum1} and {model2.__class__}:{sum2} are different")
        return False
    
    for p1, p2 in zip(params1, params2):
        if not torch.equal(p1, p2):
            print(f"weights of {model1.__class__} and {model2.__class__} are different")
            return False
    
    print(f"models {model1.__class__} and {model2.__class__} are the same")
    return True

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


def train():
    model_id = "google/gemma-2b"
    tokenizer =  AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=16, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=16, lora_dropout=0.05,)
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            max_steps=40, # small epochs for brevity, but the same is also with larger epochs
            output_dir="output/trained_model",
            optim="adafactor",
            logging_steps=1,
            learning_rate=3e-4,
        ),
        peft_config=lora_config,
        max_seq_length=512,
    )
    trainer.train()
    trainer.save_model()
    
    merged_model = trainer.model.merge_and_unload() # merge LORA with base model
    merged_model.to("cpu")
    print(type(merged_model), count_parameters(merged_model))
    merged_model.save_pretrained("adapters_merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("adapters_merged")
    print(type(trained_model), count_parameters(trained_model))
    original_model = AutoModelForCausalLM.from_pretrained(model_id)
    print(type(original_model), count_parameters(original_model))
    check_model_weights_equality(trained_model, original_model)

if __name__ == "__main__":
    train()

And here was the output:

[2024-03-18 20:15:02,900] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
torch.__version__='2.1.0a0+32f93b1'
peft.__version__='0.8.2'
trl.__version__='0.7.10'
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.38it/s]
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:290: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
{'loss': 2.8281, 'grad_norm': 0.66015625, 'learning_rate': 0.00029249999999999995, 'epoch': 0.0}                 
{'loss': 2.7031, 'grad_norm': 0.6171875, 'learning_rate': 0.000285, 'epoch': 0.01}                               
{'loss': 2.7344, 'grad_norm': 0.765625, 'learning_rate': 0.00027749999999999997, 'epoch': 0.01}                  
{'loss': 2.5469, 'grad_norm': 0.8125, 'learning_rate': 0.00027, 'epoch': 0.02}                                   
{'loss': 2.4688, 'grad_norm': 0.96484375, 'learning_rate': 0.0002625, 'epoch': 0.02}                             
{'loss': 2.3906, 'grad_norm': 0.90625, 'learning_rate': 0.00025499999999999996, 'epoch': 0.03}                   
{'loss': 2.4219, 'grad_norm': 1.1015625, 'learning_rate': 0.00024749999999999994, 'epoch': 0.03}                 
{'loss': 2.2344, 'grad_norm': 0.9296875, 'learning_rate': 0.00023999999999999998, 'epoch': 0.03}                 
{'loss': 2.2031, 'grad_norm': 1.015625, 'learning_rate': 0.00023249999999999999, 'epoch': 0.04}                  
{'loss': 2.0312, 'grad_norm': 0.96484375, 'learning_rate': 0.000225, 'epoch': 0.04}                              
{'loss': 2.0938, 'grad_norm': 1.1015625, 'learning_rate': 0.00021749999999999997, 'epoch': 0.05}                 
{'loss': 2.0, 'grad_norm': 1.296875, 'learning_rate': 0.00020999999999999998, 'epoch': 0.05}                     
{'loss': 1.8281, 'grad_norm': 3.078125, 'learning_rate': 0.0002025, 'epoch': 0.06}                               
{'loss': 1.7656, 'grad_norm': 1.9609375, 'learning_rate': 0.000195, 'epoch': 0.06}                               
{'loss': 1.7031, 'grad_norm': 3.859375, 'learning_rate': 0.00018749999999999998, 'epoch': 0.07}                  
{'loss': 1.6484, 'grad_norm': 2.171875, 'learning_rate': 0.00017999999999999998, 'epoch': 0.07}                  
{'loss': 1.5859, 'grad_norm': 2.453125, 'learning_rate': 0.00017249999999999996, 'epoch': 0.07}                  
{'loss': 1.5312, 'grad_norm': 1.96875, 'learning_rate': 0.000165, 'epoch': 0.08}                                 
{'loss': 1.5391, 'grad_norm': 1.8671875, 'learning_rate': 0.00015749999999999998, 'epoch': 0.08}                 
{'loss': 1.3828, 'grad_norm': 2.109375, 'learning_rate': 0.00015, 'epoch': 0.09}                                 
{'loss': 1.3438, 'grad_norm': 3.609375, 'learning_rate': 0.0001425, 'epoch': 0.09}                               
{'loss': 1.2969, 'grad_norm': 2.671875, 'learning_rate': 0.000135, 'epoch': 0.1}                                 
{'loss': 1.2344, 'grad_norm': 3.328125, 'learning_rate': 0.00012749999999999998, 'epoch': 0.1}                   
{'loss': 1.2891, 'grad_norm': 2.9375, 'learning_rate': 0.00011999999999999999, 'epoch': 0.1}                     
{'loss': 1.2656, 'grad_norm': 2.109375, 'learning_rate': 0.0001125, 'epoch': 0.11}                               
{'loss': 1.0938, 'grad_norm': 2.890625, 'learning_rate': 0.00010499999999999999, 'epoch': 0.11}                  
{'loss': 1.0391, 'grad_norm': 2.46875, 'learning_rate': 9.75e-05, 'epoch': 0.12}                                 
{'loss': 1.1016, 'grad_norm': 2.859375, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.12}                   
{'loss': 1.0625, 'grad_norm': 2.421875, 'learning_rate': 8.25e-05, 'epoch': 0.13}                                
{'loss': 0.957, 'grad_norm': 2.4375, 'learning_rate': 7.5e-05, 'epoch': 0.13}                                    
{'loss': 0.9219, 'grad_norm': 1.703125, 'learning_rate': 6.75e-05, 'epoch': 0.13}                                
{'loss': 0.8906, 'grad_norm': 1.7734375, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.14}                 
{'loss': 0.9609, 'grad_norm': 4.40625, 'learning_rate': 5.2499999999999995e-05, 'epoch': 0.14}                   
{'loss': 0.875, 'grad_norm': 2.109375, 'learning_rate': 4.4999999999999996e-05, 'epoch': 0.15}                   
{'loss': 0.9219, 'grad_norm': 2.8125, 'learning_rate': 3.75e-05, 'epoch': 0.15}                                  
{'loss': 0.9102, 'grad_norm': 2.125, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.16}                     
{'loss': 0.9258, 'grad_norm': 1.515625, 'learning_rate': 2.2499999999999998e-05, 'epoch': 0.16}                  
{'loss': 0.8164, 'grad_norm': 1.8515625, 'learning_rate': 1.4999999999999999e-05, 'epoch': 0.17}                 
{'loss': 0.8164, 'grad_norm': 2.0, 'learning_rate': 7.499999999999999e-06, 'epoch': 0.17}                        
{'loss': 0.8086, 'grad_norm': 1.6484375, 'learning_rate': 0.0, 'epoch': 0.17}                                    
{'train_runtime': 5.337, 'train_samples_per_second': 14.99, 'train_steps_per_second': 7.495, 'train_loss': 1.554296875, 'epoch': 0.17}
100%|████████████████████████████████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.50it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.78it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.81it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
models <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> are the same

So the issue has nothing to do with the TPU, for sure.
cc @amyeroberts @pacman100 @muellerzr

However, one thing I would like to verify is your way of checking whether the model weights are equal or not. I will get back to you on that.
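One thing worth noting about the check_model_weights_equality helper used in both scripts: nn.Module.parameters() returns a generator, so after sum(p.numel() for p in params1) has consumed it, the later zip(params1, params2) yields nothing and the element-wise comparison is silently skipped; the helper then reports "models ... are the same" whenever the parameter counts match. A minimal sketch of a comparison that avoids this by iterating over fresh named_parameters() calls (the weights_match name is arbitrary):

# Compare two models parameter-by-parameter without exhausting any iterators.
import torch

def weights_match(model1, model2):
    params1 = dict(model1.named_parameters())
    params2 = dict(model2.named_parameters())
    if params1.keys() != params2.keys():
        print("Parameter names differ")
        return False
    for name, p1 in params1.items():
        if not torch.equal(p1.data, params2[name].data):
            print(f"Parameter values differ: {name}")
            return False
    print("Models have identical parameters")
    return True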

@moficodes

# Note: `trainer`, `model`, `model_id`, `new_model_id`, and
# `check_model_weights_equality` refer to the earlier scripts in this thread.
from peft import PeftModel

trainer.save_model(new_model_id)
# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)

newmodel = PeftModel.from_pretrained(base_model, new_model_id)
newmodel = newmodel.merge_and_unload()

print(check_model_weights_equality(model, newmodel))

Logs:

Number of parameters are different in <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>:9327324160 and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>:8537680896 are different
False

@PawKanarek
Author

PawKanarek commented Mar 19, 2024

@moficodes I think you misunderstood my intentions. I want to save a standalone model, not just the LoRA adapter. You saved only the LoRA adapter (with trainer.save_model()), but my problem is with loading/saving the merged model after merge_and_unload().

Please take a look at this updated script. I changed the comparison function to be more descriptive, and I added more logging as @shub-kris asked.

# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft.peft_model import PeftModel
from trl import SFTTrainer
from transformers import logging, IntervalStrategy

device = xm.xla_device() # Set up TPU device.

def models_equal(model1, model2):
    name1, name2 = model1.__class__.__name__, model2.__class__.__name__
    params1, params2 = model1.parameters(), model2.parameters() 
    sum1, sum2 = sum(p.numel() for p in params1), sum(p.numel() for p in params2)
    
    if (sum1 != sum2):
        print(f"!!! numer of params are different in {name1}:{sum1} params vs {name2}:{sum2} params")
    
    for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
        if n1 != n2:
            print(f"!!! Parameter names differ: {n1} != {n2}")
            return False
        if not torch.equal(p1.data, p2.data):
            print(f"!!! Parameter values differ: {n1}, {p1.data}, {p2.data}")
            return False
        
    print(f"!!! models {name1} and {name2} are the same")
    return True

def train():
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            logging_steps=1, 
            save_strategy=IntervalStrategy.EPOCH,
            per_device_train_batch_size=64,
            num_train_epochs=1,
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    trainer.train()
    trainer.save_model()
    
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", return_dict=True, torch_dtype=torch.bfloat16)
    new_model = PeftModel.from_pretrained(base_model, "output/trained_model")
    new_model = new_model.merge_and_unload()
    new_model.save_pretrained("output/new_model")
    
    new_model_from_disk = AutoModelForCausalLM.from_pretrained("output/new_model", torch_dtype=torch.bfloat16)
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    print(f"are equal after load from disk? {models_equal(base_model, new_model_from_disk)}") # they equal after loading from disk 
    print(1)


if __name__ == "__main__":
    logging.set_verbosity(logging.DEBUG)
    train()
And the output:
(v_xla) raix@t1v-n-3a1a9ef8-w-0:~/minefinetune$  cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 35991 -- /home/raix/minefinetune/server/train.py 
loading configuration file config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/config.json
Model config GemmaConfig {
  "_name_or_path": "google/gemma-2b-it",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

loading weights file model.safetensors from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/model.safetensors.index.json
Instantiating GemmaForCausalLM model under default dtype torch.bfloat16.
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.53it/s]
All model checkpoint weights were used when initializing GemmaForCausalLM.

All the weights of GemmaForCausalLM were initialized from the model checkpoint at google/gemma-2b-it.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GemmaForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

loading file tokenizer.model from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/tokenizer.model
loading file tokenizer.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/special_tokens_map.json
loading file tokenizer_config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/tokenizer_config.json
PyTorch: setting up devices
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710807544.909560 1196716 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710807544.909638 1196716 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710807544.909646 1196716 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:316: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
Currently training with a batch size of: 64
***** Running training *****
  Num examples = 448
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 7
  Number of trainable parameters = 663,552
  0%|                                                                                                         | 0/7 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'loss': 5.6564, 'grad_norm': 0.3203125, 'learning_rate': 4.2857142857142856e-05, 'epoch': 0.14}                                    
{'loss': 5.6915, 'grad_norm': 0.322265625, 'learning_rate': 3.571428571428572e-05, 'epoch': 0.29}                                   
{'loss': 5.6364, 'grad_norm': 0.333984375, 'learning_rate': 2.857142857142857e-05, 'epoch': 0.43}                                   
{'loss': 5.6424, 'grad_norm': 0.34765625, 'learning_rate': 2.1428571428571428e-05, 'epoch': 0.57}                                   
{'loss': 5.6617, 'grad_norm': 0.3515625, 'learning_rate': 1.4285714285714285e-05, 'epoch': 0.71}                                    
{'loss': 5.6785, 'grad_norm': 0.35546875, 'learning_rate': 7.142857142857143e-06, 'epoch': 0.86}                                    
{'loss': 5.6422, 'grad_norm': 0.35546875, 'learning_rate': 0.0, 'epoch': 1.0}                                                       
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:09<00:00,  5.12s/it]Saving model checkpoint to output/trained_model/checkpoint-7
loading configuration file config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/config.json
Model config GemmaConfig {
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

tokenizer config file saved in output/trained_model/checkpoint-7/tokenizer_config.json
Special tokens file saved in output/trained_model/checkpoint-7/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 104.6427, 'train_samples_per_second': 4.281, 'train_steps_per_second': 0.067, 'train_loss': 5.658452238355364, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:44<00:00, 14.94s/it]
Saving model checkpoint to output/trained_model
loading configuration file config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/config.json
Model config GemmaConfig {
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

tokenizer config file saved in output/trained_model/tokenizer_config.json
Special tokens file saved in output/trained_model/special_tokens_map.json
loading configuration file config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/config.json
Model config GemmaConfig {
  "_name_or_path": "google/gemma-2b-it",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

loading weights file model.safetensors from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/model.safetensors.index.json
Instantiating GemmaForCausalLM model under default dtype torch.bfloat16.
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.56it/s]
All model checkpoint weights were used when initializing GemmaForCausalLM.

All the weights of GemmaForCausalLM were initialized from the model checkpoint at google/gemma-2b-it.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GemmaForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

Configuration saved in output/new_model/config.json
Configuration saved in output/new_model/generation_config.json
The model is bigger than the maximum size per checkpoint (5GB) and is going to be split in 2 checkpoint shards. You can find where each parameters has been saved in the index located at output/new_model/model.safetensors.index.json.
loading configuration file output/new_model/config.json
Model config GemmaConfig {
  "_name_or_path": "output/new_model",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

loading weights file output/new_model/model.safetensors.index.json
Instantiating GemmaForCausalLM model under default dtype torch.bfloat16.
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.87it/s]
All model checkpoint weights were used when initializing GemmaForCausalLM.

All the weights of GemmaForCausalLM were initialized from the model checkpoint at output/new_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GemmaForCausalLM for predictions without further training.
loading configuration file output/new_model/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

loading configuration file config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/config.json
Model config GemmaConfig {
  "_name_or_path": "google/gemma-2b-it",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

loading weights file model.safetensors from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/model.safetensors.index.json
Instantiating GemmaForCausalLM model under default dtype torch.bfloat16.
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.83it/s]
All model checkpoint weights were used when initializing GemmaForCausalLM.

All the weights of GemmaForCausalLM were initialized from the model checkpoint at google/gemma-2b-it.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GemmaForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /home/raix/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0
}

!!! models GemmaForCausalLM and GemmaForCausalLM are the same
are equal after load from disk? True
1

As you can see, at the end I again get the message that the base model and the model loaded from disk are the same:

!!! models GemmaForCausalLM and GemmaForCausalLM are the same
are equal after load from disk? True

I'm open to investigating further.

@shub-kris
Contributor

shub-kris commented Mar 19, 2024

Hi @PawKanarek, I tried a new script that is very similar to yours, and I ran inference before and after training; the results are different, which verifies that the model was trained and saved correctly.

Script

#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")


device = xm.xla_device() # Set up TPU device.

def inference(model, tokenizer):
    text = "Quote: Imagination is more"
    device = "cpu"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=20) #generate only supported on GPU and CPU
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


def train():
    model_id = "google/gemma-2b"
    
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer =  AutoTokenizer.from_pretrained(model_id)
    # tokenizer.pad_token = tokenizer.eos_token
    
    #Load and process dataset
    raw_dataset = load_dataset("Abirate/english_quotes", split="train")
    lora_config = LoraConfig(r=8, target_modules="all-linear", task_type="CAUSAL_LM", lora_alpha=16, lora_dropout=0.05,)
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}

    trainer = SFTTrainer(
        model=model,
        # train_dataset=format_dataset,
        train_dataset=raw_dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=32,
            num_train_epochs=10,
            output_dir="output",
            optim="adafactor",
            logging_steps=1,
            learning_rate=3e-4,
            save_strategy="no",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=1024,
        packing=True,
        dataset_text_field="quote",
    )
    trainer.train()
    trainer.save_model()
    
    merged_model = trainer.model.merge_and_unload() # merge LORA with base model
    merged_model.to("cpu")
    merged_model.save_pretrained("adapters_merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("adapters_merged")    
    original_model = AutoModelForCausalLM.from_pretrained(model_id)

    print("Inference with base model: \n\n")
    inference(original_model, tokenizer)
    
    print("Inference with trained model: \n\n")
    inference(trained_model, tokenizer)
    
if __name__ == "__main__":
    train()

logs

torch.__version__='2.3.0'
torch_xla.__version__='2.3.0+gite385c2f'
peft.__version__='0.8.2'
trl.__version__='0.7.12.dev0'
{'loss': 5.0312, 'grad_norm': 3.109375, 'learning_rate': 0.00029, 'epoch': 0.33}                                                                                                                                                                                                                                                                                         
{'loss': 4.7812, 'grad_norm': 2.921875, 'learning_rate': 0.00028, 'epoch': 0.67}                                                                                                                                                                                                                                                                                         
{'loss': 4.5312, 'grad_norm': 4.15625, 'learning_rate': 0.00027, 'epoch': 1.0}                                                                                                                                                                                                                                                                                           
{'loss': 4.1875, 'grad_norm': 3.90625, 'learning_rate': 0.00026, 'epoch': 1.33}                                                                                                                                                                                                                                                                                          
{'loss': 3.9062, 'grad_norm': 4.46875, 'learning_rate': 0.00025, 'epoch': 1.67}                                                                                                                                                                                                                                                                                          
{'loss': 3.75, 'grad_norm': 4.15625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}                                                                                                                                                                                                                                                                              
{'loss': 3.4688, 'grad_norm': 4.46875, 'learning_rate': 0.00023, 'epoch': 2.33}                                                                                                                                                                                                                                                                                          
{'loss': 3.3438, 'grad_norm': 3.71875, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}                                                                                                                                                                                                                                                                           
{'loss': 3.2656, 'grad_norm': 3.5, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}                                                                                                                                                                                                                                                                                
{'loss': 3.0781, 'grad_norm': 2.734375, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}                                                                                                                                                                                                                                                                          
{'loss': 3.0, 'grad_norm': 2.328125, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}                                                                                                                                                                                                                                                                             
{'loss': 2.9531, 'grad_norm': 1.796875, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}                                                                                                                                                                                                                                                                           
{'loss': 2.875, 'grad_norm': 2.5, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}                                                                                                                                                                                                                                                                                
{'loss': 2.8281, 'grad_norm': 3.15625, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}                                                                                                                                                                                                                                                                           
{'loss': 2.7969, 'grad_norm': 3.546875, 'learning_rate': 0.00015, 'epoch': 5.0}                                                                                                                                                                                                                                                                                          
{'loss': 2.7188, 'grad_norm': 1.4375, 'learning_rate': 0.00014, 'epoch': 5.33}                                                                                                                                                                                                                                                                                           
{'loss': 2.7188, 'grad_norm': 2.21875, 'learning_rate': 0.00013, 'epoch': 5.67}                                                                                                                                                                                                                                                                                          
{'loss': 2.7656, 'grad_norm': 3.40625, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}                                                                                                                                                                                                                                                                            
{'loss': 2.6875, 'grad_norm': 4.6875, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}                                                                                                                                                                                                                                                                            
{'loss': 2.625, 'grad_norm': 1.6015625, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}                                                                                                                                                                                                                                                                           
{'loss': 2.6562, 'grad_norm': 1.546875, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}                                                                                                                                                                                                                                                                            
{'loss': 2.6562, 'grad_norm': 1.703125, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}                                                                                                                                                                                                                                                                           
{'loss': 2.5938, 'grad_norm': 1.40625, 'learning_rate': 7e-05, 'epoch': 7.67}                                                                                                                                                                                                                                                                                            
{'loss': 2.625, 'grad_norm': 1.1796875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}                                                                                                                                                                                                                                                                           
{'loss': 2.6562, 'grad_norm': 1.5078125, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}                                                                                                                                                                                                                                                                         
{'loss': 2.5, 'grad_norm': 1.0234375, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}                                                                                                                                                                                                                                                                            
{'loss': 2.5156, 'grad_norm': 1.359375, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}                                                                                                                                                                                                                                                                           
{'loss': 2.5, 'grad_norm': 1.03125, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}                                                                                                                                                                                                                                                                              
{'loss': 2.5938, 'grad_norm': 1.125, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}                                                                                                                                                                                                                                                                              
{'loss': 2.5, 'grad_norm': 0.97265625, 'learning_rate': 0.0, 'epoch': 10.0}                                                                                                                                                                                                                                                                                              
{'train_runtime': 386.8015, 'train_samples_per_second': 2.482, 'train_steps_per_second': 0.078, 'train_loss': 3.103645833333333, 'epoch': 10.0}   

Inference Results

  1. With original model
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
  2. With finetuned model
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa

@amyeroberts we can close this issue #29659 and also the issue #29608

@shub-kris
Contributor

I also tried without FSDP, as it makes fine-tuning easier (a sketch of the change is below the results):

With finetuned model I got this result:

Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
Author: Albert Einstein
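
The thread doesn't show the exact no-FSDP configuration, so the following is only a hedged sketch of what dropping FSDP from the script above might look like; whether it matches the actual run is an assumption.

# Hedged sketch only: the same TrainingArguments as in the script above, with the
# FSDP/XLA sharding arguments simply omitted.
from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=32,
    num_train_epochs=10,
    output_dir="output",
    optim="adafactor",
    logging_steps=1,
    learning_rate=3e-4,
    save_strategy="no",
    dataloader_drop_last=True,
    # fsdp="full_shard" and fsdp_config=... are left out here.
)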

@PawKanarek
Author

Thank you @shub-kris! I will run this script on my local machine and then share the results.
I have one question regarding your code: why do you set the following?

tokenizer.pad_token = tokenizer.eos_token

@shub-kris
Contributor

shub-kris commented Mar 19, 2024

It sets the tokenizer's padding token to be the same as its end-of-sequence (EOS) token. But you don't need it for this use case, as the tokenizer already has a pad_token defined here.
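
As a quick sanity check (a hedged sketch, not from the thread; it assumes you have access to the gated Gemma checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")
print(tok.pad_token)  # expected to print a pad token (the config above has pad_token_id: 0) rather than None
if tok.pad_token is None:  # only needed for tokenizers that ship without a pad token
    tok.pad_token = tok.eos_token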

@zorrofox

zorrofox commented Mar 19, 2024

@shub-kris ,

The inference result I got on TPU is not like yours.

logs:

torch.__version__='2.3.0.dev20240312+cu121'
torch_xla.__version__='2.3.0+git97acc14'
peft.__version__='0.9.0'
trl.__version__='0.7.11'
config.json: 100%|███████████████████████████████████████████████████████████████████████████| 627/627 [00:00<00:00, 3.63MB/s]
model.safetensors.index.json: 100%|██████████████████████████████████████████████████████| 13.5k/13.5k [00:00<00:00, 49.6MB/s]
model-00001-of-00002.safetensors: 100%|███████████████████████████████████████████████████| 4.95G/4.95G [00:22<00:00, 218MB/s]
model-00002-of-00002.safetensors: 100%|███████████████████████████████████████████████████| 67.1M/67.1M [00:00<00:00, 222MB/s]
Downloading shards: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:23<00:00, 11.71s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.05it/s]
generation_config.json: 100%|█████████████████████████████████████████████████████████████████| 137/137 [00:00<00:00, 832kB/s]
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████| 1.11k/1.11k [00:00<00:00, 7.57MB/s]
tokenizer.model: 100%|████████████████████████████████████████████████████████████████████| 4.24M/4.24M [00:00<00:00, 207MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████| 17.5M/17.5M [00:00<00:00, 223MB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████████████| 555/555 [00:00<00:00, 1.27MB/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710857864.890503    6669 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710857864.890575    6669 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710857864.890586    6669 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:104: UserWarning: `devkind` argument is deprecated and will be removed in a future release.
  warnings.warn("`devkind` argument is deprecated and will be removed in a "
Generating train split: 102 examples [00:00, 159.49 examples/s]
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:294: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: 
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  warnings.warn(
  0%|                                                                                                  | 0/30 [00:00<?, ?it/s]/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'loss': 4.273, 'grad_norm': 3.953125, 'learning_rate': 0.00029, 'epoch': 0.33}                                               
{'loss': 4.1232, 'grad_norm': 4.25, 'learning_rate': 0.00028, 'epoch': 0.67}                                                  
{'loss': 3.7796, 'grad_norm': 5.3125, 'learning_rate': 0.00027, 'epoch': 1.0}                                                 
{'loss': 3.4005, 'grad_norm': 4.59375, 'learning_rate': 0.00026, 'epoch': 1.33}                                               
{'loss': 3.2413, 'grad_norm': 3.046875, 'learning_rate': 0.00025, 'epoch': 1.67}                                              
{'loss': 2.9242, 'grad_norm': 1.9765625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}                               
{'loss': 2.7689, 'grad_norm': 2.5625, 'learning_rate': 0.00023, 'epoch': 2.33}                                                
{'loss': 2.7829, 'grad_norm': 2.046875, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}                               
{'loss': 2.6584, 'grad_norm': 3.984375, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}                                
{'loss': 2.6561, 'grad_norm': 1.6171875, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}                              
{'loss': 2.5347, 'grad_norm': 5.34375, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}                                
{'loss': 2.4281, 'grad_norm': 2.328125, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}                                
{'loss': 2.4578, 'grad_norm': 3.015625, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}                               
{'loss': 2.5122, 'grad_norm': 1.4765625, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}                              
{'loss': 2.3117, 'grad_norm': 2.125, 'learning_rate': 0.00015, 'epoch': 5.0}                                                  
{'loss': 2.3832, 'grad_norm': 2.109375, 'learning_rate': 0.00014, 'epoch': 5.33}                                              
{'loss': 2.3193, 'grad_norm': 1.609375, 'learning_rate': 0.00013, 'epoch': 5.67}                                              
{'loss': 2.2856, 'grad_norm': 2.109375, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}                                
{'loss': 2.2524, 'grad_norm': 1.7421875, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}                              
{'loss': 2.2826, 'grad_norm': 1.328125, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}                                
{'loss': 2.1978, 'grad_norm': 1.109375, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}                                 
{'loss': 2.2295, 'grad_norm': 1.078125, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}                                
{'loss': 2.1379, 'grad_norm': 1.21875, 'learning_rate': 7e-05, 'epoch': 7.67}                                                 
{'loss': 2.2398, 'grad_norm': 1.6171875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}                               
{'loss': 2.1681, 'grad_norm': 0.890625, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}                               
{'loss': 2.176, 'grad_norm': 5.96875, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}                                 
{'loss': 2.1323, 'grad_norm': 0.89453125, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}                              
{'loss': 2.1921, 'grad_norm': 0.87109375, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}                             
{'loss': 2.0294, 'grad_norm': 5.625, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}                                   
{'loss': 2.1877, 'grad_norm': 0.73046875, 'learning_rate': 0.0, 'epoch': 10.0}                                                
{'train_runtime': 567.8682, 'train_samples_per_second': 1.691, 'train_steps_per_second': 0.053, 'train_loss': 2.6022108157475787, 'epoch': 10.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████| 30/30 [09:27<00:00, 18.93s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.31it/s]
Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.33it/s]
Inference with base model: 


Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
Inference with trained model: 


Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa
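
Not part of the thread's resolution, but for anyone hitting the same `_orig_module` warning shown above: one possible workaround is to rename the offending keys in the saved checkpoint so they match the names GemmaForCausalLM expects. The directory name below and the assumption that renaming alone suffices are mine.

# Hedged workaround sketch: strip the FSDP/XLA "._orig_module" segment from the
# checkpoint keys so they match the parameter names GemmaForCausalLM expects.
import json
import os
from safetensors.torch import load_file, save_file

ckpt_dir = "adapters_merged"  # directory of the merged checkpoint (assumption)

for name in os.listdir(ckpt_dir):
    if name.endswith(".safetensors"):
        path = os.path.join(ckpt_dir, name)
        tensors = load_file(path)
        fixed = {k.replace("._orig_module", ""): v for k, v in tensors.items()}
        save_file(fixed, path, metadata={"format": "pt"})

# For a sharded checkpoint, model.safetensors.index.json maps key names to shards,
# so its "weight_map" entries need the same renaming.
index_path = os.path.join(ckpt_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
    with open(index_path) as f:
        index = json.load(f)
    index["weight_map"] = {k.replace("._orig_module", ""): v for k, v in index["weight_map"].items()}
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)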

@shub-kris
Contributor

@zorrofox it's the same as what I got: #29659 (comment)

@zorrofox

@zorrofox it's the same as what I got: #29659 (comment)

But the inference result is very different.

@shub-kris
Contributor

shub-kris commented Mar 19, 2024

@zorrofox nothing is different. Please go through the comment once again; if something is different, what exactly is different?

Or are you referring to this comment: #29659 (comment)? There, I tried without FSDP.

@PawKanarek
Author

I think my original method for comparing weights was broken. When I accessed the parameters with params1 = model1.parameters(),
the method returned a generator, which can only be iterated once. In my original comparison function I iterated it twice, so the function was buggy... :( Look at this sample code:

params1 = model1.parameters()
print(len(list(params1))) # prints 164
print(len(list(params1))) # prints 0 
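
For completeness, a corrected comparison can build fresh dicts from named_parameters() and compare tensors directly; the helper below is only an illustrative sketch (the function name and the use of torch.equal are not from the thread).

import torch

def models_equal(model1, model2) -> bool:
    # Build fresh dicts; .parameters()/.named_parameters() return generators that
    # are exhausted after a single pass, which was the bug described above.
    params1 = dict(model1.named_parameters())
    params2 = dict(model2.named_parameters())
    if params1.keys() != params2.keys():
        return False
    return all(torch.equal(params1[k], params2[k]) for k in params1)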

I tried your code, @shub-kris, and I get exactly the same result from the merged model:

Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa

That output looks broken, and I still get this warning when loading the merged model:

Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

But maybe this should be addressed in another issue. Thanks once more for investigating and debugging.

@amyeroberts
Collaborator

@PawKanarek

That looks kind of broken, and I still experience this warning when loading the merged model

Is this happening when you're loading a saved model?

@PawKanarek
Author

Is this happening when you're loading a saved model?

@amyeroberts No, I copied that warning message from the comment by @zorrofox #29659 (comment), but I remember that I also experienced this warning.

To be 100% certain, I once again launched the code from this comment by @shub-kris #29659 (comment), and this is my output:

torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.67it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710873891.917161 1297506 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710873891.917242 1297506 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710873891.917250 1297506 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:316: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
  0%|                                                                                                        | 0/30 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'loss': 4.254, 'grad_norm': 3.453125, 'learning_rate': 0.00029, 'epoch': 0.33}                                                     
{'loss': 4.1319, 'grad_norm': 3.78125, 'learning_rate': 0.00028, 'epoch': 0.67}                                                     
{'loss': 3.8043, 'grad_norm': 4.5, 'learning_rate': 0.00027, 'epoch': 1.0}                                                          
{'loss': 3.4729, 'grad_norm': 3.859375, 'learning_rate': 0.00026, 'epoch': 1.33}                                                    
{'loss': 3.1394, 'grad_norm': 3.375, 'learning_rate': 0.00025, 'epoch': 1.67}                                                       
{'loss': 2.9524, 'grad_norm': 2.015625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}                                      
{'loss': 2.8268, 'grad_norm': 1.703125, 'learning_rate': 0.00023, 'epoch': 2.33}                                                    
{'loss': 2.6656, 'grad_norm': 1.4609375, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}                                    
{'loss': 2.7338, 'grad_norm': 4.21875, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}                                       
{'loss': 2.6369, 'grad_norm': 2.40625, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}                                      
{'loss': 2.5441, 'grad_norm': 2.21875, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}                                      
{'loss': 2.4651, 'grad_norm': 2.6875, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}                                        
{'loss': 2.3907, 'grad_norm': 11.375, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}                                       
{'loss': 2.3174, 'grad_norm': 3.875, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}                                        
{'loss': 2.489, 'grad_norm': 1.609375, 'learning_rate': 0.00015, 'epoch': 5.0}                                                      
{'loss': 2.2825, 'grad_norm': 1.4921875, 'learning_rate': 0.00014, 'epoch': 5.33}                                                   
{'loss': 2.3592, 'grad_norm': 2.3125, 'learning_rate': 0.00013, 'epoch': 5.67}                                                      
{'loss': 2.4066, 'grad_norm': 1.859375, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}                                      
{'loss': 2.2769, 'grad_norm': 2.515625, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}                                     
{'loss': 2.2699, 'grad_norm': 1.65625, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}                                       
{'loss': 2.267, 'grad_norm': 1.4765625, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}                                       
{'loss': 2.0841, 'grad_norm': 1.21875, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}                                       
{'loss': 2.3272, 'grad_norm': 2.0625, 'learning_rate': 7e-05, 'epoch': 7.67}                                                        
{'loss': 2.2218, 'grad_norm': 2.6875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}                                        
{'loss': 2.1625, 'grad_norm': 0.74609375, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}                                   
{'loss': 2.1687, 'grad_norm': 1.203125, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}                                     
{'loss': 2.153, 'grad_norm': 7.65625, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}                                        
{'loss': 2.1273, 'grad_norm': 1.359375, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}                                     
{'loss': 2.1455, 'grad_norm': 3.015625, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}                                      
{'loss': 2.2011, 'grad_norm': 1.0078125, 'learning_rate': 0.0, 'epoch': 10.0}                                                       
{'train_runtime': 250.5815, 'train_samples_per_second': 3.831, 'train_steps_per_second': 0.12, 'train_loss': 2.6092514197031655, 'epoch': 10.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [04:10<00:00,  8.35s/it]
tcmalloc: large alloc 2097152000 bytes == 0x52ee08000 @  0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x6a8630000 @  0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.93it/s]
Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tcmalloc: large alloc 2097152000 bytes == 0x6a8630000 @  0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x8fde30000 @  0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.79it/s]
Inference with base model: 


Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
Inference with trained model: 


Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa

As you can see, I also experience that kind of warning when loading the merged model.
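
A quick way to confirm that this is a key-naming mismatch (the FSDP wrapper's `_orig_module` prefix leaking into the checkpoint) rather than genuinely missing weights is to list the keys of the saved shards. The following is only a diagnostic sketch; the `adapters_merged` directory and the sharded safetensors layout are assumptions based on the output above.

# Diagnostic sketch: count how many saved keys still carry the FSDP wrapper
# prefix `_orig_module`. Paths are assumptions based on the logs above.
import glob
from safetensors import safe_open

prefixed, clean = 0, 0
for shard in glob.glob("adapters_merged/*.safetensors"):
    with safe_open(shard, framework="pt") as f:
        for key in f.keys():
            if "._orig_module." in key:
                prefixed += 1
            else:
                clean += 1
print(f"keys with '_orig_module': {prefixed}, keys without: {clean}")
# A non-zero `prefixed` count means the weights are present but stored under
# names that GemmaForCausalLM does not expect, which matches the warning above.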

@shub-kris
Contributor

@PawKanarek I am now able to replicate the error/warning you get; earlier I couldn't.

While debugging, I found that I only encounter this error when running with FSDP. I am looking into what is not working, whether it's the saving or something else.

Can you please re-run the script with these lines commented out:

            #dataloader_drop_last = True,  # Required for SPMD.
            #fsdp="full_shard",
            #fsdp_config=fsdp_config,

and reduce the batch size according to your TPU, then post the results here again.
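
Roughly, the requested change to the reproduction script would look like the sketch below; apart from the three commented-out arguments, the values are placeholders rather than the exact ones from the original script.

# Sketch of the suggested change: drop the SPMD/FSDP-specific arguments and
# shrink the per-device batch size. Values here are placeholders.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,  # reduced to fit on the TPU without FSDP sharding
    num_train_epochs=1,
    logging_steps=1,
    # dataloader_drop_last=True,  # Required for SPMD.
    # fsdp="full_shard",
    # fsdp_config=fsdp_config,
)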

cc @amyeroberts

@shub-kris shub-kris reopened this Mar 19, 2024
@PawKanarek
Author

@shub-kris With FSDP commented out and the batch size reduced to 1, I could finally see a properly fine-tuned model, without any warnings.

output (click arrow to expand)
(v_xla) raix@t1v-n-3a1a9ef8-w-0:~/minefinetune$  cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 59669 -- /home/raix/minefinetune/server/t.py 
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.76it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710880764.001550 1317987 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710880764.001618 1317987 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710880764.001631 1317987 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:316: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
  0%|                                                                                                       | 0/102 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'loss': 4.5565, 'grad_norm': 5.90625, 'learning_rate': 0.00029705882352941177, 'epoch': 0.01}                                      
{'loss': 3.6308, 'grad_norm': 3.859375, 'learning_rate': 0.0002941176470588235, 'epoch': 0.02}                                      
{'loss': 3.48, 'grad_norm': 14.375, 'learning_rate': 0.00029117647058823524, 'epoch': 0.03}                                         
{'loss': 4.0616, 'grad_norm': 4.6875, 'learning_rate': 0.00028823529411764703, 'epoch': 0.04}                                       
{'loss': 3.2878, 'grad_norm': 3.1875, 'learning_rate': 0.00028529411764705877, 'epoch': 0.05}                                       
{'loss': 2.9565, 'grad_norm': 3.453125, 'learning_rate': 0.00028235294117647056, 'epoch': 0.06}                                     
{'loss': 2.7303, 'grad_norm': 2.53125, 'learning_rate': 0.00027941176470588236, 'epoch': 0.07}                                      
{'loss': 2.8276, 'grad_norm': 2.296875, 'learning_rate': 0.0002764705882352941, 'epoch': 0.08}                                      
{'loss': 2.7869, 'grad_norm': 2.5625, 'learning_rate': 0.00027352941176470583, 'epoch': 0.09}                                       
{'loss': 2.5918, 'grad_norm': 3.046875, 'learning_rate': 0.0002705882352941176, 'epoch': 0.1}                                       
{'loss': 2.5682, 'grad_norm': 2.703125, 'learning_rate': 0.0002676470588235294, 'epoch': 0.11}                                      
{'loss': 2.5969, 'grad_norm': 2.34375, 'learning_rate': 0.00026470588235294115, 'epoch': 0.12}                                      
{'loss': 2.4699, 'grad_norm': 13.8125, 'learning_rate': 0.00026176470588235295, 'epoch': 0.13}                                      
{'loss': 2.3431, 'grad_norm': 6.03125, 'learning_rate': 0.0002588235294117647, 'epoch': 0.14}                                       
{'loss': 2.4726, 'grad_norm': 5.625, 'learning_rate': 0.0002558823529411764, 'epoch': 0.15}                                         
{'loss': 2.4611, 'grad_norm': 31.0, 'learning_rate': 0.0002529411764705882, 'epoch': 0.16}                                          
{'loss': 2.2907, 'grad_norm': 2.78125, 'learning_rate': 0.00025, 'epoch': 0.17}                                                     
{'loss': 2.3958, 'grad_norm': 4.5625, 'learning_rate': 0.00024705882352941174, 'epoch': 0.18}                                       
{'loss': 2.3724, 'grad_norm': 1.8125, 'learning_rate': 0.0002441176470588235, 'epoch': 0.19}                                        
{'loss': 2.0032, 'grad_norm': 5.3125, 'learning_rate': 0.00024117647058823527, 'epoch': 0.2}                                        
{'loss': 2.599, 'grad_norm': 6.5625, 'learning_rate': 0.000238235294117647, 'epoch': 0.21}                                          
{'loss': 2.1205, 'grad_norm': 2.703125, 'learning_rate': 0.0002352941176470588, 'epoch': 0.22}                                      
{'loss': 1.691, 'grad_norm': 2.203125, 'learning_rate': 0.00023235294117647057, 'epoch': 0.23}                                      
{'loss': 2.1085, 'grad_norm': 1.890625, 'learning_rate': 0.0002294117647058823, 'epoch': 0.24}                                      
{'loss': 2.4238, 'grad_norm': 11.25, 'learning_rate': 0.0002264705882352941, 'epoch': 0.25}                                         
{'loss': 2.3273, 'grad_norm': 2.171875, 'learning_rate': 0.00022352941176470586, 'epoch': 0.25}                                     
{'loss': 1.8184, 'grad_norm': 1.796875, 'learning_rate': 0.00022058823529411765, 'epoch': 0.26}                                     
{'loss': 1.76, 'grad_norm': 1.515625, 'learning_rate': 0.0002176470588235294, 'epoch': 0.27}                                        
{'loss': 2.8188, 'grad_norm': 1.3359375, 'learning_rate': 0.00021470588235294116, 'epoch': 0.28}                                    
{'loss': 1.885, 'grad_norm': 1.53125, 'learning_rate': 0.00021176470588235295, 'epoch': 0.29}                                       
{'loss': 2.3718, 'grad_norm': 1.921875, 'learning_rate': 0.0002088235294117647, 'epoch': 0.3}                                       
{'loss': 2.3778, 'grad_norm': 3.328125, 'learning_rate': 0.00020588235294117645, 'epoch': 0.31}                                     
{'loss': 2.4162, 'grad_norm': 1.4453125, 'learning_rate': 0.00020294117647058822, 'epoch': 0.32}                                    
{'loss': 2.2694, 'grad_norm': 12.4375, 'learning_rate': 0.00019999999999999998, 'epoch': 0.33}                                      
{'loss': 1.2874, 'grad_norm': 5.46875, 'learning_rate': 0.00019705882352941175, 'epoch': 0.34}                                      
{'loss': 2.2275, 'grad_norm': 2.4375, 'learning_rate': 0.0001941176470588235, 'epoch': 0.35}                                        
{'loss': 1.4792, 'grad_norm': 1.40625, 'learning_rate': 0.00019117647058823528, 'epoch': 0.36}                                      
{'loss': 1.3559, 'grad_norm': 1.875, 'learning_rate': 0.00018823529411764704, 'epoch': 0.37}                                        
{'loss': 1.9698, 'grad_norm': 1.4609375, 'learning_rate': 0.0001852941176470588, 'epoch': 0.38}                                     
{'loss': 1.8739, 'grad_norm': 1.8125, 'learning_rate': 0.00018235294117647055, 'epoch': 0.39}                                       
{'loss': 0.6814, 'grad_norm': 1.078125, 'learning_rate': 0.00017941176470588234, 'epoch': 0.4}                                      
{'loss': 2.2777, 'grad_norm': 1.734375, 'learning_rate': 0.0001764705882352941, 'epoch': 0.41}                                      
{'loss': 2.4052, 'grad_norm': 1.125, 'learning_rate': 0.0001735294117647059, 'epoch': 0.42}                                         
{'loss': 1.1264, 'grad_norm': 1.265625, 'learning_rate': 0.00017058823529411763, 'epoch': 0.43}                                     
{'loss': 1.4286, 'grad_norm': 1.234375, 'learning_rate': 0.0001676470588235294, 'epoch': 0.44}                                      
{'loss': 2.0239, 'grad_norm': 1.203125, 'learning_rate': 0.0001647058823529412, 'epoch': 0.45}                                      
{'loss': 2.4702, 'grad_norm': 0.9609375, 'learning_rate': 0.00016176470588235293, 'epoch': 0.46}                                    
{'loss': 1.6212, 'grad_norm': 1.0703125, 'learning_rate': 0.0001588235294117647, 'epoch': 0.47}                                     
{'loss': 1.4587, 'grad_norm': 4.5, 'learning_rate': 0.00015588235294117646, 'epoch': 0.48}                                          
{'loss': 2.3347, 'grad_norm': 1.4296875, 'learning_rate': 0.00015294117647058822, 'epoch': 0.49}                                    
{'loss': 1.7701, 'grad_norm': 1.2734375, 'learning_rate': 0.00015, 'epoch': 0.5}                                                    
{'loss': 2.4789, 'grad_norm': 1.296875, 'learning_rate': 0.00014705882352941175, 'epoch': 0.51}                                     
{'loss': 2.3662, 'grad_norm': 3.484375, 'learning_rate': 0.00014411764705882352, 'epoch': 0.52}                                     
{'loss': 2.2018, 'grad_norm': 1.59375, 'learning_rate': 0.00014117647058823528, 'epoch': 0.53}                                      
{'loss': 2.2774, 'grad_norm': 1.1796875, 'learning_rate': 0.00013823529411764705, 'epoch': 0.54}                                    
{'loss': 1.691, 'grad_norm': 1.265625, 'learning_rate': 0.0001352941176470588, 'epoch': 0.55}                                       
{'loss': 2.4592, 'grad_norm': 1.0625, 'learning_rate': 0.00013235294117647058, 'epoch': 0.56}                                       
{'loss': 2.1323, 'grad_norm': 1.1875, 'learning_rate': 0.00012941176470588234, 'epoch': 0.57}                                       
{'loss': 2.14, 'grad_norm': 1.171875, 'learning_rate': 0.0001264705882352941, 'epoch': 0.58}                                        
{'loss': 2.0911, 'grad_norm': 2.109375, 'learning_rate': 0.00012352941176470587, 'epoch': 0.59}                                     
{'loss': 2.3724, 'grad_norm': 1.171875, 'learning_rate': 0.00012058823529411764, 'epoch': 0.6}                                      
{'loss': 1.9369, 'grad_norm': 1.0859375, 'learning_rate': 0.0001176470588235294, 'epoch': 0.61}                                     
{'loss': 2.1488, 'grad_norm': 1.5078125, 'learning_rate': 0.00011470588235294115, 'epoch': 0.62}                                    
{'loss': 2.5139, 'grad_norm': 1.0546875, 'learning_rate': 0.00011176470588235293, 'epoch': 0.63}                                    
{'loss': 2.2037, 'grad_norm': 1.328125, 'learning_rate': 0.0001088235294117647, 'epoch': 0.64}                                      
{'loss': 1.4069, 'grad_norm': 1.65625, 'learning_rate': 0.00010588235294117647, 'epoch': 0.65}                                      
{'loss': 1.6892, 'grad_norm': 1.0625, 'learning_rate': 0.00010294117647058823, 'epoch': 0.66}                                       
{'loss': 2.6367, 'grad_norm': 1.21875, 'learning_rate': 9.999999999999999e-05, 'epoch': 0.67}                                       
{'loss': 2.0439, 'grad_norm': 7.0625, 'learning_rate': 9.705882352941176e-05, 'epoch': 0.68}                                        
{'loss': 2.0848, 'grad_norm': 1.4375, 'learning_rate': 9.411764705882352e-05, 'epoch': 0.69}                                        
{'loss': 2.3307, 'grad_norm': 1.0, 'learning_rate': 9.117647058823527e-05, 'epoch': 0.7}                                            
{'loss': 2.4189, 'grad_norm': 0.95703125, 'learning_rate': 8.823529411764705e-05, 'epoch': 0.71}                                    
{'loss': 2.4486, 'grad_norm': 1.4765625, 'learning_rate': 8.529411764705882e-05, 'epoch': 0.72}                                     
{'loss': 2.0123, 'grad_norm': 1.2421875, 'learning_rate': 8.23529411764706e-05, 'epoch': 0.73}                                      
{'loss': 1.3727, 'grad_norm': 1.15625, 'learning_rate': 7.941176470588235e-05, 'epoch': 0.74}                                       
{'loss': 1.6614, 'grad_norm': 1.59375, 'learning_rate': 7.647058823529411e-05, 'epoch': 0.75}                                       
{'loss': 1.9256, 'grad_norm': 1.171875, 'learning_rate': 7.352941176470588e-05, 'epoch': 0.75}                                      
{'loss': 2.3132, 'grad_norm': 1.5859375, 'learning_rate': 7.058823529411764e-05, 'epoch': 0.76}                                     
{'loss': 2.4786, 'grad_norm': 0.91796875, 'learning_rate': 6.76470588235294e-05, 'epoch': 0.77}                                     
{'loss': 2.0456, 'grad_norm': 1.0625, 'learning_rate': 6.470588235294117e-05, 'epoch': 0.78}                                        
{'loss': 2.0947, 'grad_norm': 9.125, 'learning_rate': 6.176470588235294e-05, 'epoch': 0.79}                                         
{'loss': 2.0991, 'grad_norm': 0.93359375, 'learning_rate': 5.88235294117647e-05, 'epoch': 0.8}                                      
{'loss': 2.1226, 'grad_norm': 1.421875, 'learning_rate': 5.5882352941176466e-05, 'epoch': 0.81}                                     
{'loss': 1.5752, 'grad_norm': 1.4296875, 'learning_rate': 5.294117647058824e-05, 'epoch': 0.82}                                     
{'loss': 1.821, 'grad_norm': 0.96875, 'learning_rate': 4.9999999999999996e-05, 'epoch': 0.83}                                       
{'loss': 2.1587, 'grad_norm': 1.59375, 'learning_rate': 4.705882352941176e-05, 'epoch': 0.84}                                       
{'loss': 1.962, 'grad_norm': 4.71875, 'learning_rate': 4.4117647058823526e-05, 'epoch': 0.85}                                       
{'loss': 2.0884, 'grad_norm': 2.1875, 'learning_rate': 4.11764705882353e-05, 'epoch': 0.86}                                         
{'loss': 2.0188, 'grad_norm': 1.1328125, 'learning_rate': 3.8235294117647055e-05, 'epoch': 0.87}                                    
{'loss': 2.6892, 'grad_norm': 3.96875, 'learning_rate': 3.529411764705882e-05, 'epoch': 0.88}                                       
{'loss': 1.6855, 'grad_norm': 0.95703125, 'learning_rate': 3.2352941176470585e-05, 'epoch': 0.89}                                   
{'loss': 2.0799, 'grad_norm': 2.953125, 'learning_rate': 2.941176470588235e-05, 'epoch': 0.9}                                       
{'loss': 1.1163, 'grad_norm': 1.1796875, 'learning_rate': 2.647058823529412e-05, 'epoch': 0.91}                                     
{'loss': 2.0968, 'grad_norm': 1.078125, 'learning_rate': 2.352941176470588e-05, 'epoch': 0.92}                                      
{'loss': 1.38, 'grad_norm': 0.92578125, 'learning_rate': 2.058823529411765e-05, 'epoch': 0.93}                                      
{'loss': 1.9148, 'grad_norm': 1.203125, 'learning_rate': 1.764705882352941e-05, 'epoch': 0.94}                                      
{'loss': 1.1043, 'grad_norm': 1.2578125, 'learning_rate': 1.4705882352941175e-05, 'epoch': 0.95}                                    
{'loss': 2.1801, 'grad_norm': 0.953125, 'learning_rate': 1.176470588235294e-05, 'epoch': 0.96}                                      
{'loss': 2.2353, 'grad_norm': 1.3984375, 'learning_rate': 8.823529411764705e-06, 'epoch': 0.97}                                     
{'loss': 2.0111, 'grad_norm': 1.25, 'learning_rate': 5.88235294117647e-06, 'epoch': 0.98}                                           
{'loss': 1.6439, 'grad_norm': 1.171875, 'learning_rate': 2.941176470588235e-06, 'epoch': 0.99}                                      
{'loss': 1.7974, 'grad_norm': 1.046875, 'learning_rate': 0.0, 'epoch': 1.0}                                                         
{'train_runtime': 268.2876, 'train_samples_per_second': 0.38, 'train_steps_per_second': 0.38, 'train_loss': 2.1707937787560856, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [04:28<00:00,  2.63s/it]
tcmalloc: large alloc 2097152000 bytes == 0x1c126c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x34b80c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.63it/s]
tcmalloc: large alloc 2097152000 bytes == 0x34b80c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x741f0c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]
Inference with base model: 


Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
Inference with trained model: 


Quote: Imagination is more important than knowledge. - Albert Einstein

@shub-kris
Contributor

@PawKanarek Thank you for the confirmation. I now need to look into what's going wrong when we use FSDP, whether it's the saving or something else.

cc @amyeroberts
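
One way to narrow this down is to compare the model that is in memory right after training/merging with what comes back from disk. This is only a sketch: `merged_model` stands for the in-memory merged model, assumed to have been saved to `adapters_merged` with `save_pretrained`.

# Sketch: check whether the save/reload round-trip is where the weights diverge.
# `merged_model` is a placeholder for the in-memory merged model.
import torch
from transformers import AutoModelForCausalLM

reloaded = AutoModelForCausalLM.from_pretrained("adapters_merged")
reloaded_params = dict(reloaded.named_parameters())

mismatched = [
    name
    for name, param in merged_model.named_parameters()
    if name not in reloaded_params
    or not torch.allclose(param.detach().cpu(), reloaded_params[name].detach().cpu())
]
print(f"{len(mismatched)} parameters differ or are missing after reload")
# Many mismatches point at the saving/unwrapping step; zero mismatches would mean
# the checkpoint faithfully reflects the in-memory weights.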

@shub-kris
Contributor

@alanwaketan can you please look into this?

@alanwaketan
Contributor

I think the issue is that for the FSDP-wrapped model, we need to unwrap the model before saving it. I have given @shub-kris instructions for fixing the unwrap logic in HF.

If things don't work out in HF, I will provide a utility in torch-xla to unwrap the model.
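
Until that unwrap logic lands, the idea can be sketched as a workaround: strip the wrapper's `_orig_module` prefix from the state-dict keys before saving, so the checkpoint uses the names GemmaForCausalLM expects. This is only an illustration, not the upstream fix; `trained_model` and the output directory are placeholders.

# Workaround sketch (not the upstream fix): save a state dict with the FSDP
# wrapper's `_orig_module` prefix removed from every key.
clean_state_dict = {
    key.replace("._orig_module", ""): value.cpu()
    for key, value in trained_model.state_dict().items()
}
trained_model.save_pretrained("adapters_merged_clean", state_dict=clean_state_dict)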

@shub-kris
Contributor

shub-kris commented Mar 21, 2024

@PawKanarek @zorrofox can you now try with PR #29780?

For me everything works perfectly now.

@zorrofox

@shub-kris This time the merged-model loading warning has disappeared, but the inference result is not very good.

train output:

torch.__version__='2.3.0.dev20240312+cu121'
torch_xla.__version__='2.3.0+git97acc14'
peft.__version__='0.9.0'
trl.__version__='0.7.11'
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711089816.717824   13174 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1711089816.717896   13174 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1711089816.717907   13174 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:104: UserWarning: `devkind` argument is deprecated and will be removed in a future release.
  warnings.warn("`devkind` argument is deprecated and will be removed in a "
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:294: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  warnings.warn(
  0%|          | 0/30 [00:00<?, ?it/s]
  warnings.warn("For backward hooks to be called,"
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  3%|███▌      | 1/30 [01:31<44:15, 91.58s/it]
{'loss': 4.273, 'grad_norm': 3.734375, 'learning_rate': 0.00029, 'epoch': 0.33}
{'loss': 4.1188, 'grad_norm': 4.09375, 'learning_rate': 0.00028, 'epoch': 0.67}
{'loss': 3.7833, 'grad_norm': 5.21875, 'learning_rate': 0.00027, 'epoch': 1.0}
{'loss': 3.3963, 'grad_norm': 4.375, 'learning_rate': 0.00026, 'epoch': 1.33}
{'loss': 3.2293, 'grad_norm': 2.75, 'learning_rate': 0.00025, 'epoch': 1.67}
{'loss': 2.9096, 'grad_norm': 1.9296875, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}
{'loss': 2.7546, 'grad_norm': 1.953125, 'learning_rate': 0.00023, 'epoch': 2.33}
{'loss': 2.7717, 'grad_norm': 2.453125, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}
{'loss': 2.6428, 'grad_norm': 2.921875, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}
{'loss': 2.6672, 'grad_norm': 1.6640625, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}
{'loss': 2.5239, 'grad_norm': 4.4375, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}
{'loss': 2.4252, 'grad_norm': 1.578125, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}
{'loss': 2.455, 'grad_norm': 2.265625, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}
{'loss': 2.5093, 'grad_norm': 1.4921875, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}
{'loss': 2.2936, 'grad_norm': 2.0, 'learning_rate': 0.00015, 'epoch': 5.0}
{'loss': 2.3667, 'grad_norm': 2.1875, 'learning_rate': 0.00014, 'epoch': 5.33}
{'loss': 2.3081, 'grad_norm': 2.21875, 'learning_rate': 0.00013, 'epoch': 5.67}
{'loss': 2.2664, 'grad_norm': 3.59375, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}
{'loss': 2.2403, 'grad_norm': 6.5625, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}
{'loss': 2.269, 'grad_norm': 5.28125, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}
{'loss': 2.2112, 'grad_norm': 1.15625, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}
{'loss': 2.2353, 'grad_norm': 1.2265625, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}
{'loss': 2.15, 'grad_norm': 1.3984375, 'learning_rate': 7e-05, 'epoch': 7.67}
{'loss': 2.2592, 'grad_norm': 1.359375, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}
{'loss': 2.185, 'grad_norm': 0.8359375, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}
{'loss': 2.1976, 'grad_norm': 0.75390625, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}
{'loss': 2.1421, 'grad_norm': 0.8203125, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}
{'loss': 2.2024, 'grad_norm': 0.81640625, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}
{'loss': 2.0441, 'grad_norm': 0.90625, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}
{'loss': 2.1902, 'grad_norm': 0.70703125, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 574.9741, 'train_samples_per_second': 1.67, 'train_steps_per_second': 0.052, 'train_loss': 2.600708842277527, 'epoch': 10.0}
100%|██████████| 30/30 [09:34<00:00, 19.17s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
Inference with base model:

Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
Inference with trained model:

Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. Imagination is the only way to

@shub-kris
Contributor

shub-kris commented Mar 22, 2024

@zorrofox try training for longer, as your losses are still high. I don't remember the exact hyperparameters I used, but I was able to get decent results.
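Roughly, "training longer" just means bumping the schedule in TrainingArguments, for example (a sketch only; the values below are illustrative, not the ones I used):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="fsdp_output",        # adapter output directory, used later for merging
    num_train_epochs=20,             # more epochs than the 10-epoch run above
    learning_rate=2e-4,              # illustrative; tune for your dataset
    per_device_train_batch_size=4,   # illustrative
    logging_steps=1,
)

# trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=lora_config, ...)
# trainer.train()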

Thanks for confirming that the issue with saving and reloading the weights is resolved.

@PawKanarek
Author

That's great news @shub-kris! Thank you for the quick fix and the hard work! I will post an update when I'm done with my current training runs (my workaround still works and I don't want to break my pipeline). Could you provide minimal pseudo-code with the correct pattern for unloading and merging the LoRA adapter into a standalone model?
Will this be correct?

trainer = SFTTrainer(...)
trainer.train()
merged_model = trainer.model.merge_and_unload() # merge LORA with base model
merged_model.to("cpu")
merged_model.save_pretrained("adapters_merged")

Is this OK? Or do I also need to call trainer.save_model() after training?

@shub-kris
Contributor

shub-kris commented Mar 25, 2024

@PawKanarek in the training script, I recommend doing the training and then saving the model:

  trainer.train()
  # saving final model
  trainer.save_model()

Merging can be done in a separate script to avoid any TPU or FSDP wrapper issues. I follow the approach described here: https://huggingface.co/docs/trl/en/use_model#use-adapters-peft

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model
base_model_name = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(base_model_name)

# Load the trained PEFT (LoRA) adapter on top of the base model
adapter_model_name = "fsdp_output"
print(f"Adapter model is {adapter_model_name}")
trained_peft_model = PeftModel.from_pretrained(model, adapter_model_name)

# Merge the LoRA weights into the base model and save it as a standalone model
merged_model = trained_peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model")
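And to sanity-check the merged checkpoint, it can be loaded back like any other standalone model (a minimal sketch; the prompt and generation length are arbitrary):

from transformers import AutoModelForCausalLM, AutoTokenizer

# The tokenizer is not part of the merge, so load it from the base model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("merged_model")

inputs = tokenizer("Quote: Imagination is more important than", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))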

@PawKanarek
Author

The fix from the PR and saving with the given pattern work flawlessly. Thank you @shub-kris 👨‍💻
