Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/setfithead multi target #272

Merged
merged 17 commits into from
Jan 19, 2023
59 changes: 45 additions & 14 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,18 @@ class SetFitHead(models.Dense):
The embedding dimension from the output of the SetFit body. If `None`, defaults to `LazyLinear`.
out_features (`int`, defaults to `2`):
The number of targets. If `out_features` is set to 1 for binary classification, it will be changed to 2 and treated as 2-class classification.
temperature (`float`, defaults to `1.0`):
A logits' scaling factor (i.e., number of targets more than 1).
temperature (`float`):
A logits' scaling factor. Higher values make the model less confident and lower values make
it more confident.
eps (`float`, defaults to `1e-5`):
A value for numerical stability when scaling logits.
bias (`bool`, *optional*, defaults to `True`):
Whether to add bias to the head.
device (`torch.device`, str, *optional*):
The device the model will be sent to. If `None`, will check whether GPU is available.
multitarget (`bool`, *optional*, defaults to `False`):
Enable multi-target classification by making `out_features` binary predictions instead
of a single multinomial prediction.
"""

def __init__(
Expand All @@ -129,6 +133,7 @@ def __init__(
eps: float = 1e-5,
bias: bool = True,
device: Optional[Union[torch.device, str]] = None,
multitarget: bool = False,
) -> None:
super(models.Dense, self).__init__() # init on models.Dense's parent: nn.Module

Expand All @@ -149,6 +154,7 @@ def __init__(
self.eps = eps
self.bias = bias
self._device = device or "cuda" if torch.cuda.is_available() else "cpu"
self.multitarget = multitarget

self.to(self._device)
self.apply(self._init_weight)
Expand All @@ -169,7 +175,8 @@ def forward(
make sure to store embeddings under the key: 'sentence_embedding'
and the outputs will be under the key: 'prediction'.
temperature (`float`, *optional*):
A logits' scaling factor when using multi-targets (i.e., number of targets more than 1).
A logits' scaling factor. Higher values make the model less
confident and lower values make it more confident.
Will override the temperature given during initialization.
Returns:
[`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`]
Expand All @@ -179,12 +186,12 @@ def forward(
if isinstance(features, dict):
assert "sentence_embedding" in features
is_features_dict = True

x = features["sentence_embedding"] if is_features_dict else features
logits = self.linear(x)
logits = logits / (temperature + self.eps)
probs = nn.functional.softmax(logits, dim=-1)

if self.multitarget or self.out_features == 1: # only has one target or multiple targets per item
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
probs = torch.sigmoid(logits / temperature)
else: # multiple classes, one target per item
probs = nn.functional.softmax(logits / temperature, dim=-1)
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
if is_features_dict:
features.update(
{
Expand All @@ -196,20 +203,32 @@ def forward(

return logits, probs

def predict_proba(self, x_test: torch.Tensor) -> torch.Tensor:
def predict_proba(self, x_test: Union[torch.Tensor, "ndarray"]) -> Union[torch.Tensor, "ndarray"]:
is_tensor = isinstance(x_test, torch.Tensor) # Otherwise assume it's ndarray
if not is_tensor:
x_test = torch.Tensor(x_test).to(self.device)
self.eval()

return self(x_test)[1]
out = self(x_test)[1]
if not is_tensor:
return out.detach().cpu().numpy()
return out
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved

def predict(self, x_test: torch.Tensor) -> torch.Tensor:
def predict(self, x_test: Union[torch.Tensor, "ndarray"]) -> Union[torch.Tensor, "ndarray"]:
probs = self.predict_proba(x_test)

out = torch.argmax(probs, dim=-1)

if self.out_features == 1 or self.multitarget:
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
out = np.where(probs >= 0.5, 1, 0) if isinstance(probs, np.ndarray) else torch.where(probs >= 0.5, 1, 0)
# TODO 0.5 is not suitable. I will set this as threshold in Next PR.
else:
out = np.argmax(probs, dim=-1) if isinstance(probs, np.ndarray) else torch.argmax(probs, dim=-1)
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
return out

def get_loss_fn(self):
return torch.nn.CrossEntropyLoss()
if self.out_features == 1 or self.multitarget: # if sigmoid output
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
return torch.nn.BCEWithLogitsLoss()
else:
return torch.nn.CrossEntropyLoss()

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -295,6 +314,8 @@ def fit(
# to model's device
features = {k: v.to(device) for k, v in features.items()}
labels = labels.to(device)
if self.model_head.multitarget:
labels = labels.float()
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved

outputs = self.model_body(features)
if self.normalize_embeddings:
Expand Down Expand Up @@ -404,7 +425,7 @@ def predict_proba(self, x_test: List[str], as_numpy: bool = False) -> Union[torc
outputs = self.model_head.predict_proba(embeddings)

if as_numpy and self.has_differentiable_head:
outputs = outputs.cpu().numpy()
outputs = outputs.cpu().detach().numpy()
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
elif not as_numpy and not self.has_differentiable_head:
outputs = torch.from_numpy(outputs)

Expand Down Expand Up @@ -503,12 +524,22 @@ def _from_pretrained(
else:
head_params = model_kwargs.get("head_params", {})
if use_differentiable_head:
if multi_target_strategy is None:
use_multitarget = False
else:
if multi_target_strategy in ["one-vs-rest", "multi-output"]:
use_multitarget = True
else:
raise ValueError(
f"multi_target_strategy '{multi_target_strategy}' is not supported for differentiable head"
)
# Base `model_head` parameters
# - get the sentence embedding dimension from the `model_body`
# - follow the `model_body`, put `model_head` on the target device
base_head_params = {
"in_features": model_body.get_sentence_embedding_dimension(),
"device": target_device,
"multitarget": use_multitarget,
}
model_head = SetFitHead(**{**head_params, **base_head_params})
else:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,24 @@ def test_setfit_from_pretrained_local_model_with_head(tmp_path):
assert isinstance(model, SetFitModel)


def test_setfithead_multitarget_from_pretrained():
    """Check that a differentiable head loaded with "one-vs-rest" is a multi-target `SetFitHead`.

    Verifies the head type, its `multitarget` flag, the loss function it
    returns, the number of predictions for one text, and that the predicted
    probabilities do not sum to one.
    """
    sample = "Test text"
    n_targets = 5
    pretrained = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-albert-small-v2",
        use_differentiable_head=True,
        multi_target_strategy="one-vs-rest",
        head_params={"out_features": n_targets},
    )

    head = pretrained.model_head
    assert isinstance(head, SetFitHead)
    assert head.multitarget
    loss_fn = head.get_loss_fn()
    assert isinstance(loss_fn, torch.nn.BCEWithLogitsLoss)

    # A single input text should produce one prediction per target.
    predictions = pretrained.predict(sample)
    assert len(predictions) == n_targets

    # Multi-target probabilities are not a single distribution, so they should not sum to one.
    probabilities = pretrained.predict_proba(sample, as_numpy=True)
    total = probabilities.sum()
    assert not np.isclose(total, 1)


def test_to_logistic_head():
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
devices = (
Expand Down
47 changes: 47 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,53 @@ def compute_metrics(y_pred, y_test):
)


class SetFitTrainerMultilabelDifferentiableTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2",
multi_target_strategy="one-vs-rest",
use_differentiable_head=True,
head_params={"out_features": 2},
)
self.num_iterations = 1

def test_trainer_multilabel_support_callable_as_metric(self):
dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})

multilabel_f1_metric = evaluate.load("f1", "multilabel")
multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")

def compute_metrics(y_pred, y_test):
return {
"f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
"accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
}

trainer = SetFitTrainer(
model=self.model,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metrics,
num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)

# trainer.freeze()
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
trainer.train()

# trainer.unfreeze(keep_body_frozen=False)
trainer.train(5)
metrics = trainer.evaluate()

self.assertEqual(
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
{
"f1": 1.0,
"accuracy": 1.0,
},
metrics,
)


@require_optuna
class TrainerHyperParameterOptunaIntegrationTest(TestCase):
def setUp(self):
Expand Down