SFTTrainer support #682

Merged
14 commits merged on Sep 5, 2024
24 changes: 16 additions & 8 deletions optimum/neuron/trainers.py
@@ -158,7 +158,7 @@ class PeftConfig:
transformers_get_optimizer_cls_and_kwargs = Trainer.get_optimizer_cls_and_kwargs


class AugmentTrainerForNeuronMixin:
class _TrainerForNeuron:
def __init__(self, *args, **kwargs):
if not isinstance(self, Trainer):
raise TypeError(f"{self.__class__.__name__} can only be mixed with Trainer subclasses.")
@@ -492,7 +492,11 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
tr_loss.zero_()

def log_closure(self, reduced_tr_loss, grad_norm):
if is_main_worker_for_metrics():
# We need to check that self.state.global_step > self._globalstep_last_logged because if two
# closures are added in a row (which can happen at the end of the training), then it will fail the
# second time because at this point we will have:
# self.state.global_step = self._globalstep_last_logged
if is_main_worker_for_metrics() and self.state.global_step > self._globalstep_last_logged:
logs: Dict[str, float] = {}
tr_loss_scalar = reduced_tr_loss.to("cpu").item()

@@ -1493,19 +1497,24 @@ def save_state(self):
return super().save_state()


class NeuronTrainer(AugmentTrainerForNeuronMixin, Trainer):
class NeuronTrainer(_TrainerForNeuron, Trainer):
"""
Trainer that is suited for performing training on AWS Trainium instances.
"""


class Seq2SeqNeuronTrainer(AugmentTrainerForNeuronMixin, Seq2SeqTrainer):
class Seq2SeqNeuronTrainer(_TrainerForNeuron, Seq2SeqTrainer):
"""
Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.
"""


class NeuronSFTTrainer(AugmentTrainerForNeuronMixin, SFTTrainer):
class _SFTTrainerTrainerInit(SFTTrainer):
def __init__(self, *args, **kwargs):
return Trainer.__init__(self, *args, **kwargs)


class NeuronSFTTrainer(_TrainerForNeuron, _SFTTrainerTrainerInit):
def __init__(
self,
model: Optional[Union[PreTrainedModel, torch.nn.Module, str]] = None,
@@ -1731,8 +1740,7 @@ def make_inputs_require_grad(module, input, output):
"overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
)

AugmentTrainerForNeuronMixin.__init__(
self,
super().__init__(
model=model,
args=args,
data_collator=data_collator,
@@ -1764,7 +1772,7 @@ def make_inputs_require_grad(module, input, output):
if callback.__class__.__name__ == "PrinterCallback":
self.callback_handler.pop_callback(callback)

@wraps(AugmentTrainerForNeuronMixin.train)
@wraps(_TrainerForNeuron.train)
def train(self, *args, **kwargs):
# Activate neftune right before training.
if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
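For context on the `_SFTTrainerTrainerInit` shim introduced above: `NeuronSFTTrainer.__init__` now calls `super().__init__(...)`, which walks the MRO `_TrainerForNeuron` -> `_SFTTrainerTrainerInit` -> `SFTTrainer` -> `Trainer`. Because `_SFTTrainerTrainerInit.__init__` dispatches straight to `Trainer.__init__`, the chain skips `SFTTrainer.__init__` while keeping the rest of `SFTTrainer`'s methods available. A toy sketch of the same pattern, with hypothetical class names and assuming the mixin's `__init__` forwards with `super().__init__()`:

# Toy illustration of the cooperative __init__ pattern above (hypothetical names only).
class Base:  # stands in for Trainer
    def __init__(self):
        print("Base.__init__")


class Heavy(Base):  # stands in for SFTTrainer
    def __init__(self):
        print("Heavy.__init__")  # the __init__ the shim is meant to skip
        super().__init__()


class HeavyBaseInit(Heavy):  # stands in for _SFTTrainerTrainerInit
    def __init__(self):
        Base.__init__(self)  # bypass Heavy.__init__ and go straight to Base


class Mixin:  # stands in for _TrainerForNeuron
    def __init__(self):
        print("Mixin.__init__")
        super().__init__()


class Combined(Mixin, HeavyBaseInit):  # stands in for NeuronSFTTrainer
    def __init__(self):
        super().__init__()


Combined()  # prints "Mixin.__init__" then "Base.__init__"; Heavy.__init__ never runs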
71 changes: 69 additions & 2 deletions tests/test_trainers.py
@@ -28,7 +28,7 @@
AutoModelForSequenceClassification,
)

from optimum.neuron import NeuronTrainer, NeuronTrainingArguments
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainer, NeuronTrainingArguments
from optimum.neuron.distributed.utils import MODEL_PARALLEL_SHARDS_DIR_NAME
from optimum.neuron.utils import is_neuronx_distributed_available
from optimum.neuron.utils.cache_utils import (
@@ -300,7 +300,7 @@ def create_training_args(output_dir, resume_from_checkpoint=None, max_steps=max_
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=eval_batch_size,
max_steps=max_steps,
logging_steps=1,
logging_steps=2,
save_steps=5,
do_eval=do_eval,
output_dir=output_dir,
@@ -396,3 +396,70 @@ def preprocess_function(examples):

trainer.train(resume_from_checkpoint=True)
trainer.evaluate()


@is_trainium_test
class TestNeuronSFTTrainer(DistributedTest):
@pytest.fixture(
scope="class",
params=[[2, 1, 1], [2, 2, 1]],
ids=["dp=2", "tp=2"],
)
def parallel_sizes(self, request):
return request.param

def _test_sft_trainer(self, parallel_sizes, tmpdir, packing):
_, tp_size, pp_size = parallel_sizes

output_dir = Path(tmpdir)

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
# dataset = dataset.select(range(1000))
Collaborator: nit: remove commented line


def format_dolly(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
if packing:
return prompt
return [prompt]

tokenizer, model = get_tokenizer_and_tiny_llama_model()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # to prevent warnings

args = NeuronTrainingArguments(
output_dir=output_dir,
do_train=True,
max_steps=20,
per_device_train_batch_size=1,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
logging_steps=1,
)
args = args.to_dict()
sft_config = NeuronSFTConfig(
max_seq_length=512,
packing=packing,
dataset_num_proc=1,
**args,
)

# Create Trainer instance
trainer = NeuronSFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
formatting_func=format_dolly,
args=sft_config,
)

trainer.train()
Collaborator: Can't we verify that the loss goes down or something?

Member (author): It's a tiny random model, and the SFTTrainer doesn't do anything related to the loss anyway. It's just Trainer with dataset preparation abilities.
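A minimal sanity check along these lines could look like the sketch below. It is a hypothetical addition rather than part of this PR: it only asserts that finite loss values were recorded in the standard `trainer.state.log_history` buffer, since a tiny random model gives no guarantee that the loss decreases monotonically.

import math  # hypothetical follow-up placed after trainer.train(); not in this PR

# Collect the training losses the Trainer logged (here logging_steps=1, max_steps=20).
losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
assert losses, "expected at least one logged training loss"
assert all(math.isfinite(loss) for loss in losses)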


def test_without_packing(self, parallel_sizes, tmpdir):
return self._test_sft_trainer(parallel_sizes, tmpdir, False)

def test_with_packing(self, parallel_sizes, tmpdir):
return self._test_sft_trainer(parallel_sizes, tmpdir, True)