Possible bug in SFTTrainer #1837

Closed
jjc10 opened this issue Jul 16, 2024 · 5 comments · Fixed by #1841


jjc10 commented Jul 16, 2024

Hi folks,

I am running into an issue with the SFTTrainer regarding the neftune_noise_alpha parameter.

If I leave this argument empty and let it default to None, there seems to be an error in the code flow on line 307 of the SFTTrainer. Namely (with added comments):

# True whenever `args` defines the field, even when its value is the default None
self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
# Skipped, because neftune_noise_alpha is None
if neftune_noise_alpha is not None and self._trainer_supports_neftune:
    args.neftune_noise_alpha = neftune_noise_alpha
    warnings.warn(
        "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
    )
    # self.neftune_noise_alpha is done at Trainer level
# Also skipped, because self._trainer_supports_neftune is True,
# so self.neftune_noise_alpha is never set
elif not self._trainer_supports_neftune:
    self.neftune_noise_alpha = neftune_noise_alpha

Later, train() reads the attribute at line 448 and raises AttributeError: 'SFTTrainer' object has no attribute 'neftune_noise_alpha'
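The control flow described above can be reproduced without TRL installed. This is a minimal sketch, not the TRL source: `FakeArgs` and `FakeSFTTrainer` are hypothetical stand-ins that mirror only the quoted branch, and the sketch assumes nothing else ever sets `self.neftune_noise_alpha`.

```python
import warnings


class FakeArgs:
    """Hypothetical stand-in for the config: the field exists but is None."""
    neftune_noise_alpha = None


class FakeSFTTrainer:
    """Hypothetical stand-in that mirrors only the branch quoted above."""

    def __init__(self, args, neftune_noise_alpha=None):
        # True whenever the config defines the field, even if its value is None.
        self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
        if neftune_noise_alpha is not None and self._trainer_supports_neftune:
            args.neftune_noise_alpha = neftune_noise_alpha
            warnings.warn("neftune_noise_alpha overrides the config value.")
        elif not self._trainer_supports_neftune:
            self.neftune_noise_alpha = neftune_noise_alpha
        # Assumption: no parent class sets self.neftune_noise_alpha afterwards.

    def train(self):
        # Stand-in for the attribute read that fails inside train().
        if self.neftune_noise_alpha is not None:
            pass  # NEFTune activation would happen here
        return "trained"


def reproduce():
    """Return the AttributeError message from train(), or None if it succeeds."""
    try:
        FakeSFTTrainer(FakeArgs()).train()
    except AttributeError as exc:
        return str(exc)
    return None
```

Calling `reproduce()` returns the message of the `AttributeError`, because neither branch of the `if`/`elif` assigns the attribute when the argument is left at `None`.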

Possible fix:
I believe one fix would be to change line 307 from hasattr(args, "neftune_noise_alpha") to hasattr(args, "neftune_noise_alpha") and neftune_noise_alpha is not None
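The suggested change can be sketched on a small stand-in class. `FakeArgs` and `PatchedSFTTrainer` are hypothetical names used for illustration only; this is not the TRL source, and it is not necessarily the change that was merged in #1841.

```python
class FakeArgs:
    """Hypothetical stand-in for the config: the field exists but is None."""
    neftune_noise_alpha = None


class PatchedSFTTrainer:
    """Hypothetical stand-in applying the proposed condition."""

    def __init__(self, args, neftune_noise_alpha=None):
        # Proposed change: only report support when a value was actually passed.
        self._trainer_supports_neftune = (
            hasattr(args, "neftune_noise_alpha") and neftune_noise_alpha is not None
        )
        if neftune_noise_alpha is not None and self._trainer_supports_neftune:
            args.neftune_noise_alpha = neftune_noise_alpha
        elif not self._trainer_supports_neftune:
            # Now reached for the default None, so the attribute always exists.
            self.neftune_noise_alpha = neftune_noise_alpha


trainer = PatchedSFTTrainer(FakeArgs())
```

With the default `None`, the `elif` branch now runs and `trainer.neftune_noise_alpha` is set. This sketch only covers the default path; when a non-`None` value is passed, it still relies on something else setting the attribute on the trainer.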

Additional info:

  • OS: Ubuntu 22.04.4
  • TRL 0.9.6
  • Torch 1.13.0+cu116
  • Transformers 4.34.1

Thanks and cheers!

Collaborator

kashif commented Jul 17, 2024

@jjc10 would you like to send a PR?

Collaborator

kashif commented Jul 17, 2024

@jjc10 I fixed it via the PR linked above (#1841).

Author

jjc10 commented Jul 19, 2024

Sorry, I just saw your reply, @kashif.
I didn't get a chance to open a PR myself this week. I'm reviewing yours!

@dumeixiang

Hi,
I've bumped into the same issue. I tried both neftune_noise_alpha=None and neftune_noise_alpha=0.1, and both trigger the same error: AttributeError: 'SFTTrainer' object has no attribute 'neftune_noise_alpha'

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    neftune_noise_alpha=None,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    formatting_func=format_instruction,
    args=args,
)

Author

jjc10 commented Aug 6, 2024

Hi @dumeixiang,
I will pick up the PR next week to submit a fix.
A temporary workaround is to first initialize your trainer without passing neftune_noise_alpha and right after initializing it (that is, before calling sft_trainer.train() or sft_trainer.evaluate()), you manually set sft_trainer.neftune_noise_alpha = None.

Example:

trainer = SFTTrainer(
    model=my_model,
    args=my_config,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    tokenizer=tokenizer,
    data_collator=collator,
)

# Set the missing attribute manually before calling train() or evaluate()
trainer.neftune_noise_alpha = None
trainer.train()
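A slightly more defensive variant of the workaround sets the attribute only when it is missing, so a value that was already set is not overwritten. `ensure_neftune_attr` is a hypothetical helper written for illustration and is not part of TRL; it works on any object that accepts attribute assignment.

```python
from types import SimpleNamespace


def ensure_neftune_attr(trainer):
    """Set trainer.neftune_noise_alpha to None only if the attribute is missing."""
    if not hasattr(trainer, "neftune_noise_alpha"):
        trainer.neftune_noise_alpha = None
    return trainer


# Demonstration on plain objects standing in for a trainer (assumption).
missing = ensure_neftune_attr(SimpleNamespace())
present = ensure_neftune_attr(SimpleNamespace(neftune_noise_alpha=0.1))
```

In real use this would be called on the trainer right after it is constructed and before `train()` or `evaluate()`.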
