
fix issues to be compatible with latest peft #359

Closed

Conversation

pacman100

What does this PR do?

  1. Fixes "model.save_pretrained() produced a corrupted adapter_model.bin (only 443 B) with alpaca-lora" (huggingface/peft#286) and "Bug with saving LoRA (adapter_model.bin) on latest peft from git" (huggingface/peft#317)
  2. Adds a Callback for use with the HF Trainer to make sure intermediate checkpoints save only the peft adapters, addressing huggingface/peft#286 (comment); a sketch follows this list
  3. Adds a debug_mode arg to quickly test the fine-tuning script on a tiny subset of the dataset
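A minimal sketch of what such a callback could look like (reconstructed from the PR's description, so names and details may differ from the merged code):

```python
import os

from transformers import TrainerCallback


class SavePeftModelCallback(TrainerCallback):
    """Save only the PEFT adapter weights at each intermediate checkpoint."""

    def on_save(self, args, state, control, **kwargs):
        # By the time on_save fires, the Trainer has already written a full
        # checkpoint directory; re-save just the adapter inside it...
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(os.path.join(checkpoint_dir, "adapter_model"))

        # ...and drop the full model weights the Trainer wrote, so
        # checkpoints contain only the adapter.
        full_weights = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights):
            os.remove(full_weights)
        return control
```

Such a callback would be registered via `trainer.add_callback(SavePeftModelCallback)` or the `callbacks=` argument when constructing the Trainer.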

@lksysML

lksysML commented Apr 18, 2023

Your pull request isn't working. It crashed when it tried to save a checkpoint; I was training on 8x RTX 3090.

@mcmonkey4eva

@lksysML if you're training in 8-bit, there's a separate bug in the latest bitsandbytes version that causes a massive VRAM spike when saving files. You can bypass it by downgrading: `pip install bitsandbytes==0.37.2`

@lksysML

lksysML commented Apr 18, 2023

> @lksysML if you're training in 8-bit, there's a separate bug in the latest bitsandbytes version that causes a massive VRAM spike when saving files. You can bypass it by downgrading: `pip install bitsandbytes==0.37.2`

Yeah, I ran into that issue yesterday and already fixed it. I think today's crash has something to do with peft; I rolled back to an old version and it didn't crash.
WandB report of the crash with your pull request: https://wandb.ai/lksy/huggingface/runs/hxq051bg/log

Also, debug_mode doesn't work. I set it to True and it continued with the whole training set instead of the 1024 examples it's supposed to use.

@pacman100
Author

Hello, this PR works fine for me. For debug mode, you have to pass --debug_mode on the command line.
wandb: https://wandb.ai/smangrul/huggingface/runs/3cxt4cnc?workspace=user-smangrul

With bnb==0.37, the VRAM is stable:
[screenshot: stable VRAM usage during training, 2023-04-20]
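For context: finetune.py exposes its arguments through python-fire, so a boolean parameter only becomes True when the flag is passed on the command line; assigning True in your own call site works too, but setting `debug_mode=True` has no effect if the flag never reaches the function. A hedged sketch of how such a flag might gate the dataset size (the loading code is illustrative, not the script's actual logic):

```python
import fire
from datasets import load_dataset


def train(
    data_path: str = "yahma/alpaca-cleaned",
    debug_mode: bool = False,  # with fire, passing --debug_mode sets this to True
):
    data = load_dataset(data_path)
    if debug_mode:
        # Keep only a tiny fixed slice so an end-to-end run finishes quickly
        # (the thread mentions 1024 examples).
        data["train"] = data["train"].select(range(1024))
    print(f"training on {len(data['train'])} examples")


if __name__ == "__main__":
    fire.Fire(train)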

@ElleLeonne
Contributor

ElleLeonne commented Apr 21, 2023

@pacman100

I'm a bit confused about the intended functionality here.

When attempting to train a new adapter after loading the weights of an old one, I get bad results (mainly because of my small dataset size, which is a requirement of my implementation).

I was able to get good results by continuing the training of a previous adapter instead, and that is the functionality that broke.

Now we save the adapter model in a new folder:
[screenshots of the new adapter_model checkpoint folder]

We delete pytorch_model.bin (the full weights, which the original code never actually needed; it was saving adapter_model.bin under the name pytorch_model.bin before):
[screenshot]

But the resume_from_checkpoint code still searches for pytorch_model.bin, which no longer exists:
[screenshot]

I can move the adapter out of the folder and rename it, in an attempt to restore the previous functionality, but then:
[screenshot of the resulting error]

It's not actually fixed, and the conflicting nature of this PR's code tells me this functionality wasn't tested anyway.

I think this is a pretty big issue, because the alternative is retraining the adapter in full every single time I get new data, which is time- and power-consuming compared to resuming.

@pacman100
Author

@ElleLeonne

I did test it, and it does work.

Continuing from the checkpoint of the above run via:

```sh
python finetune.py \
    --base_model 'path/to/it' \
    --data_path 'yahma/alpaca-cleaned' \
    --output_dir './lora-alpaca-2' \
    --debug_mode \
    --num_epochs=3 \
    --cutoff_len=512 \
    --lora_target_modules="[q_proj,k_proj,v_proj,o_proj]" \
    --lora_r=16 \
    --micro_batch_size=8 \
    --resume_from_checkpoint ./lora-alpaca/checkpoint-20/adapter_model/
```

results in it being correctly loaded:
[screenshot of the adapter being loaded from the checkpoint, 2023-04-21]

In the above run, the eval loss at the end was 1.70. In the run using the above command (https://wandb.ai/smangrul/huggingface/runs/18ux1bhz?workspace=user-smangrul), the eval loss starts at 1.65 and ends at 1.48.

@ElleLeonne
Contributor

ElleLeonne commented Apr 21, 2023

> @ElleLeonne
>
> I did test it, and it does work

@pacman100

I realize this isn't your repo, so I appreciate you taking the time to investigate this and help improve our little project here.

Using this method causes the code to perform the same function as loading the adapter and then starting a new training run. That's because this block:
[screenshot of the resume_from_checkpoint block]
will never execute now that we're removing pytorch_model.bin. (Previously, there was a binding that saved the LoRA adapter's state dict as pytorch_model.bin; it was the same file, just named differently.) I can try to preserve the old behaviour by moving the adapter into the main folder and renaming it, but then I hit the previous error.

As a result:
[screenshot]

the trainer will never try to load the state of the old training session. It instantiates a new session, with new warmup steps, new training-loss statistics, and a new starting point for the adapter. If I understand the process correctly, it functions more like merging the weights into the model and starting a fresh training run. Even playing with the hyperparameters (dropping the batch size, killing warmup steps, tweaking the learning rate), the new adapter always overfits like mad due to my small dataset size.

Which is all just to say that this PR doesn't restore the old functionality of the code; it cuts out half of what the old code used to do. It would be really nice to be able to load the trainer's state when continuing training, which I suppose is what broke in the first place.

@mcmonkey4eva

mcmonkey4eva commented Apr 21, 2023

Er... the logic of the code there is:

  • if pytorch_model.bin exists, then:
    • load the checkpoint file and use it
    • resume_from_checkpoint stays truthy, telling trainer.train() to load the file again (erroneously loaded twice in the pre-existing code)
  • if it does not exist, then:
    • set resume_from_checkpoint to False so trainer.train() does not load any file
    • check whether adapter_model.bin exists; if it does:
      • load the adapter_model file and use it
      • train based on that
    • else:
      • train a new model

The torch.load + set_peft_model_state_dict lines are what enable resuming a prior checkpoint; passing resume_from_checkpoint through to trainer.train() is a pre-existing mistake in the code.
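A sketch of that branching, reconstructed from the description above (the helper name is hypothetical; in the script this logic runs inline with `model` already defined):

```python
import os

import torch
from peft import set_peft_model_state_dict


def load_adapter_or_checkpoint(model, resume_from_checkpoint):
    """Hypothetical helper mirroring the branching above; returns the value
    to forward to trainer.train(resume_from_checkpoint=...)."""
    # Full Trainer checkpoint (old behaviour): pytorch_model.bin.
    checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
    if not os.path.exists(checkpoint_name):
        # Only the adapter was saved: load its weights below, but return
        # False so trainer.train() does not try to restore optimizer or
        # scheduler state that was never written to disk.
        checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
        resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        adapter_weights = torch.load(checkpoint_name)
        set_peft_model_state_dict(model, adapter_weights)
    return resume_from_checkpoint
```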

For reference, train's documentation for the resume_from_checkpoint arg:

> resume_from_checkpoint (str or bool, optional):
> If a str, local path to a saved checkpoint as saved by a previous instance of [Trainer]. If a
> bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance
> of [Trainer]. If present, training will resume from the model/optimizer/scheduler states loaded here.

The loading of adapter_model.bin (i.e. the correct file that is intended to be used) means resuming from prior saves (of adapter_model.bin) does in fact work.

@ElleLeonne
Contributor

> resume_from_checkpoint stays truthy, telling trainer.train() to load the file again (erroneously loaded twice in the pre-existing code)

I don't think this is true. There was a very clear, marked difference in quality between the old, working version of resume_from_checkpoint and the current version of resuming from just the adapter. It wasn't just a coding mistake.

> The loading of adapter_model.bin (i.e. the correct file that is intended to be used) means resuming from prior saves (of adapter_model.bin) does in fact work.

I know it works. My point was that this is only half of the required solution. I was able to get this solution working myself without the PR; the results are just bad.

#154 (comment)
See the original PR for the update in question, too.

@pacman100
Author

@ElleLeonne,

This gist https://gist.github.com/pacman100/8e7a6eedabf34e1a88dd74a96c3b619f should exhibit the behaviour you are looking for.

But it doesn't make much sense to me. Could you provide a concrete example of how the existing code with the previous peft was being used, and how this PR fails to do that, with some concrete metrics?

  1. Due to linear decay, the final checkpoint of the existing LoRA run has a learning rate close to 0, so what you were doing earlier was essentially training with a very low learning rate (see the sketch below).
  2. If you were adding the new small dataset to the initial dataset to increase the number of steps and loading the previous checkpoint for continuation, the combined dataset would have been randomly shuffled, and the steps up to the previous checkpoint would have been skipped, most likely skipping many samples from the new small dataset.

Both of the above are weird ways of continuing training.
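To illustrate point 1, a quick standalone check with a linear warmup/decay schedule (values are illustrative, not the script's actual hyperparameters):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# With linear decay, the learning rate at the final step is ~0, so
# "continuing" from the end of a finished run trains at a near-zero LR.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)
for _ in range(1000):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # ~[0.0]
```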

@ElleLeonne
Contributor

I appreciate that. I'll play around with it and see how everything works out.

The idea was basically to introduce the new data as a new "epoch", which incorporated the relevant data without needing to fully retrain the adapter every night (a full retrain takes about 7 hours for me). That lets me put off a full retrain for longer while still showing noticeable day-to-day improvements. It also mitigated a major problem where training would sometimes fail in the middle of the night and I'd have to use the old adapter until I could fix it.

I would increase the epoch count from (say) 5 to 10, so the learning rate decay would still apply but the rate would not be 0.

liuhoward added a commit to liuhoward/alpaca-lora that referenced this pull request May 5, 2023
@Maxwell-Lyu

Maxwell-Lyu commented May 18, 2023

I don't know if this is the RIGHT way, but this simple modification at L275 produces an adapter_model.bin with the right size:

```diff
- model.save_pretrained(output_dir)
+ model.save_pretrained(output_dir, state_dict=old_state_dict())
```
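For context, old_state_dict here is the handle finetune.py keeps to the original model.state_dict before monkey-patching it. A reconstructed sketch of the surrounding code (may differ from the current script; `model` and `output_dir` are assumed to be defined earlier):

```python
from peft import get_peft_model_state_dict

# finetune.py swaps model.state_dict for a version that returns only the
# LoRA weights, keeping the original method around as old_state_dict.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

# Newer peft versions filter the state dict inside save_pretrained, so
# handing it the already-filtered patched state_dict can plausibly yield the
# near-empty 443 B file; passing the unfiltered original state dict
# explicitly avoids the double filtering.
model.save_pretrained(output_dir, state_dict=old_state_dict())
```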
