SetFit partially up-to-date
AjayP13 committed Jan 14, 2025
1 parent 70a6ea0 commit 08c7407
Showing 4 changed files with 115 additions and 60 deletions.
4 changes: 0 additions & 4 deletions src/tests/trainers/test_trainers.py
@@ -3059,10 +3059,6 @@ def test_reward_model_from_trainer(self, create_datadreamer, mocker):


class TestTrainSetFitClassifier:
# TODO: SetFit is currently broken (skipping tests):
# https://github.com/huggingface/setfit/issues/564
__test__ = False

def test_metadata(self, create_datadreamer):
with create_datadreamer():
trainer = TrainSetFitClassifier(
16 changes: 9 additions & 7 deletions src/trainers/train_hf_dpo.py
@@ -114,9 +114,9 @@ def _train( # type:ignore[override] # noqa: C901
from ._vendored.dpo_trainer import DPOTrainer # type: ignore[attr-defined]

# Prepare datasets
assert self._is_encoder_decoder or truncate, (
"`truncate=False` is not supported for this model."
)
assert (
self._is_encoder_decoder or truncate
), "`truncate=False` is not supported for this model."
train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs(
self,
train_columns={
@@ -358,10 +358,12 @@ def _train( # type:ignore[override] # noqa: C901
prepared_model, trainer.optimizer
)
else:
(prepared_model, trainer.optimizer, trainer.lr_scheduler) = (
trainer.accelerator.prepare(
prepared_model, trainer.optimizer, trainer.lr_scheduler
)
(
prepared_model,
trainer.optimizer,
trainer.lr_scheduler,
) = trainer.accelerator.prepare(
prepared_model, trainer.optimizer, trainer.lr_scheduler
)
trainer.model_wrapped = prepared_model
if trainer.is_fsdp_enabled:
114 changes: 75 additions & 39 deletions src/trainers/train_setfit_classifier.py
@@ -1,19 +1,24 @@
import base64
import json
import os
import sys
from copy import deepcopy
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Type
from unittest.mock import patch

import dill
import torch

from datasets import IterableDataset

from .. import DataDreamer
from .._cachable._cachable import _is_primitive
from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn
from ..utils.arg_utils import AUTO, DEFAULT, Default, default_to
from ..utils.background_utils import RunIfTimeout
from ..utils.device_utils import _SentenceTransformerTrainingArgumentDeviceOverrideMixin
from ..utils.hf_model_utils import (
filter_model_warnings,
get_base_model_from_peft_model,
@@ -32,7 +37,10 @@

with ignore_transformers_warnings():
from huggingface_hub.repocard import ModelCard, ModelCardData
from sentence_transformers import SentenceTransformer
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainingArguments,
)
from setfit import SetFitModel, logging as setfit_logging
from transformers import PreTrainedModel
from transformers.trainer_callback import EarlyStoppingCallback, PrinterCallback
@@ -95,6 +103,14 @@ def _save_pretrained(self, save_directory: Path | str) -> None:
self.model_head.to(self.device)


class _SentenceTransformerTrainingArguments(
_SentenceTransformerTrainingArgumentDeviceOverrideMixin,
SentenceTransformerTrainingArguments,
):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class TrainSetFitClassifier(TrainHFClassifier):
def __init__(
self,
@@ -110,9 +126,9 @@ def __init__(
**kwargs,
):
cls_name = self.__class__.__name__
assert not isinstance(
device, list
), f"Training on multiple devices is not supported for {cls_name}."
assert not isinstance(device, list), (
f"Training on multiple devices is not supported for {cls_name}."
)
_TrainHFBase.__init__(
self,
name=name,
@@ -184,7 +200,7 @@ def _create_model(
head_params={"out_features": len(label2id)},
labels=list(label2id.keys()),
**self.kwargs,
)
).to(default_to(device, self.device))

# Set model dtype
model.model_body = model.model_body.to(self.dtype)
@@ -249,27 +265,24 @@ def _train( # type:ignore[override]
from transformers.trainer_callback import ProgressCallback

# Prepare datasets
(
train_dataset,
validation_dataset,
label2id,
is_multi_target,
) = prepare_inputs_and_outputs(
self,
train_columns={
("text", "Train Input"): train_input,
("label", "Train Output"): train_output,
},
validation_columns={
("text", "Validation Input"): validation_input,
("label", "Validation Output"): validation_output,
},
truncate=truncate,
(train_dataset, validation_dataset, label2id, is_multi_target) = (
prepare_inputs_and_outputs(
self,
train_columns={
("text", "Train Input"): train_input,
("label", "Train Output"): train_output,
},
validation_columns={
("text", "Validation Input"): validation_input,
("label", "Validation Output"): validation_output,
},
truncate=truncate,
)
)
id2label = {v: k for k, v in label2id.items()}
assert (
len(id2label) > 1
), "There must be at least 2 output labels in your dataset."
assert len(id2label) > 1, (
"There must be at least 2 output labels in your dataset."
)

# Prepare metrics
metric = kwargs.pop("metric", "f1")
@@ -319,6 +332,7 @@ def _train( # type:ignore[override]
seed=seed,
**kwargs,
)
training_args.place_model_on_device = False
training_args.eval_strategy = training_args.evaluation_strategy
if kwargs.get("max_steps", None) is not None: # pragma: no cover
total_train_steps = kwargs["max_steps"]
@@ -345,6 +359,28 @@

# Setup trainer
class CustomTrainer(Trainer):
def __init__(self, *args, **kwargs):
model_body = kwargs["model"].model_body
orig_body_device = model_body.device
os.environ["_DATADREAMER_SETFIT_DEVICE"] = base64.b64encode(
dill.dumps(orig_body_device)
).decode("utf-8")
with patch(
"sentence_transformers.trainer.SentenceTransformerTrainingArguments",
_SentenceTransformerTrainingArguments,
):
super().__init__(*args, **kwargs)
_ = (
self.st_trainer.args._selected_device
) # Read the property to store it in cache
del os.environ["_DATADREAMER_SETFIT_DEVICE"]

def _wrap_model(self, model, *args, **kwargs):
return model

def _move_model_to_device(self, model, *args, **kwargs):
return None

def train_classifier(trainer, *args, **kwargs):
self.logger.info("Finished training SetFit model body (embeddings).")

@@ -387,19 +423,19 @@ def train_classifier(trainer, *args, **kwargs):
setfit_logging.disable_progress_bar()
return results

def evaluate(self, *args, **kwargs):
metrics = {
"epoch": "Final"
if "final" in kwargs
else round(self.state.epoch or 0.0, 2)
}
kwargs.pop("final", None)
for k, v in super().evaluate(*args, **kwargs).items():
metrics[f"eval_{k}"] = v
self.callback_handler.on_log(
training_args, self.state, self.control, metrics
)
return metrics
# def evaluate(self, *args, **kwargs):
# metrics = {
# "epoch": "Final"
# if "final" in kwargs
# else round(self.state.epoch or 0.0, 2)
# }
# kwargs.pop("final", None)
# for k, v in super().evaluate(*args, **kwargs).items():
# metrics[f"eval_{k}"] = v
# self.callback_handler.on_log(
# training_args, self.state, self.control, metrics
# )
# return metrics

trainer = CustomTrainer(
train_dataset=train_dataset,
@@ -539,7 +575,7 @@ def _load_model(
use_differentiable_head=True,
labels=list(label2id.keys()),
**self.kwargs,
)
).to(self.device)
if os.path.exists(
os.path.join(
os.path.join(self._output_folder_path, "_model"),
@@ -575,7 +611,7 @@ def _load_model(
use_differentiable_head=True,
labels=list(label2id.keys()),
**self.kwargs,
)
).to(self.device)

# Set model dtype
model.model_body = model.model_body.to(self.dtype)
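For context on the `with patch(...)` call inside `CustomTrainer.__init__` above: `unittest.mock.patch` temporarily rebinds a name inside the target module, so any code in that module that looks the class up by its module-level name constructs the replacement instead. The sketch below is a minimal, self-contained illustration of that pattern using made-up stand-in names (`LibraryArgs`, `build_args`, `PatchedArgs`); it is not the real setfit or sentence-transformers internals.

from unittest.mock import patch


class LibraryArgs:
    # Stand-in for a third-party training-arguments class.
    def __init__(self, output_dir: str):
        self.output_dir = output_dir


def build_args(output_dir: str) -> LibraryArgs:
    # Stand-in for library code that instantiates the class via its module-level name.
    return LibraryArgs(output_dir)


class PatchedArgs(LibraryArgs):
    # Subclass that injects extra behaviour, loosely analogous to the
    # _SentenceTransformerTrainingArguments subclass defined in the diff above.
    def __init__(self, output_dir: str):
        super().__init__(output_dir)
        self.patched = True


# While the patch is active, build_args() constructs PatchedArgs instead of LibraryArgs.
with patch(f"{__name__}.LibraryArgs", PatchedArgs):
    args = build_args("out")

print(type(args).__name__, getattr(args, "patched", False))  # -> PatchedArgs True

In the actual diff, the patched name is `sentence_transformers.trainer.SentenceTransformerTrainingArguments`, so the trainer constructed inside `super().__init__(*args, **kwargs)` builds the device-overriding subclass instead of the stock class.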
41 changes: 31 additions & 10 deletions src/utils/device_utils.py
@@ -1,7 +1,10 @@
import base64
import os
import re
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence, Type, cast

import dill
import torch

from .. import DataDreamer
@@ -85,9 +88,9 @@ def get_true_device_ids(

def get_device_env_variables(devices: list[int | str | torch.device]) -> dict[str, Any]:
_, true_device_ids = get_true_device_ids(devices)
assert (
len(true_device_ids) == len(devices)
), f"The device list you specified ({devices}) is invalid (or devices could not be found)."
assert len(true_device_ids) == len(devices), (
f"The device list you specified ({devices}) is invalid (or devices could not be found)."
)
device_env = {"CUDA_VISIBLE_DEVICES": ",".join(map(str, true_device_ids))}
device_env["NCCL_P2P_DISABLE"] = "1"
return device_env
@@ -101,11 +104,8 @@ def memory_usage_format(num, suffix="B"): # pragma: no cover
return f"{num:.1f}Yi{suffix}"


class _TrainingArgumentDeviceOverrideMixin:
class __TrainingArgumentDeviceOverrideMixin:
def __init__(self, *args, **kwargs):
from .distributed_utils import apply_distributed_config

kwargs = apply_distributed_config(self, kwargs)
super().__init__(*args, **kwargs)

@property
@@ -151,6 +151,27 @@ def device(self) -> torch.device:
)


class _SentenceTransformerTrainingArgumentDeviceOverrideMixin(
__TrainingArgumentDeviceOverrideMixin
):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@cached_property
def _selected_device(self):
return dill.loads(
    base64.b64decode(os.environ["_DATADREAMER_SETFIT_DEVICE"].encode("utf8"))
)


class _TrainingArgumentDeviceOverrideMixin(__TrainingArgumentDeviceOverrideMixin):
def __init__(self, *args, **kwargs):
from .distributed_utils import apply_distributed_config

kwargs = apply_distributed_config(self, kwargs)
super().__init__(*args, **kwargs)


def get_device_memory_monitoring_callback(trainer: "_TrainHFBase") -> Type:
from .distributed_utils import (
get_current_accelerator,
@@ -259,9 +280,9 @@ def model_to_device(
to_device_map = "auto"
to_device_map_max_memory = max_memory
else:
assert all(
is_cpu_device(d) for d in list_of_devices
), f"The device you specified ({list_of_devices}) is invalid (or devices could not be found)."
assert all(is_cpu_device(d) for d in list_of_devices), (
f"The device you specified ({list_of_devices}) is invalid (or devices could not be found)."
)
to_device_map = {"": "cpu"}
to_device_map_max_memory = None
else:
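The `_selected_device` cached property above is the receiving end of the `_DATADREAMER_SETFIT_DEVICE` environment variable set in `CustomTrainer.__init__`. Below is a minimal, self-contained sketch of that dill/base64 round-trip, assuming only that `torch` and `dill` are installed; `_DeviceOverride` and `stash_device` are simplified stand-ins for illustration, not the real mixin.

import base64
import os
from functools import cached_property

import dill
import torch

ENV_KEY = "_DATADREAMER_SETFIT_DEVICE"  # same env var name as in the diff above


def stash_device(device: torch.device) -> None:
    # Serialize the device with dill and store it base64-encoded in the environment,
    # so a class constructed later (possibly in another module) can recover it
    # without holding a direct reference.
    os.environ[ENV_KEY] = base64.b64encode(dill.dumps(device)).decode("utf-8")


class _DeviceOverride:
    # Simplified stand-in for the SentenceTransformer training-arguments mixin.
    @cached_property
    def _selected_device(self) -> torch.device:
        # Decode and deserialize the device; cached_property memoizes the result,
        # so the env var can be deleted right after the first read.
        return dill.loads(base64.b64decode(os.environ[ENV_KEY].encode("utf-8")))


stash_device(torch.device("cpu"))
args = _DeviceOverride()
_ = args._selected_device   # read once to populate the cache
del os.environ[ENV_KEY]     # safe: the value now lives on the instance
print(args._selected_device)  # -> device(type='cpu')

This mirrors why the diff reads `self.st_trainer.args._selected_device` and then deletes the environment variable inside `CustomTrainer.__init__`: once the cached property is populated, the environment no longer needs to carry the value.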
