Skip to content

Commit

Permalink
Add more loss function options (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay authored Nov 8, 2022
1 parent 6b58b58 commit 641fbfc
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class SetFitTrainer:
warmup_proportion (`float`, *optional*, defaults to `0.1`):
Proportion of the warmup in the total training steps.
Must be greater than or equal to 0.0 and less than or equal to 1.0.
distance_metric (`Callable`, defaults to `BatchHardTripletLossDistanceFunction.cosine_distance`):
Function that returns a distance between two embeddings.
It is set for the triplet loss and
is ignored for `CosineSimilarityLoss` and `SupConLoss`.
margin (`float`, defaults to `0.25`): Margin for the triplet loss.
Negative samples should be at least margin further apart from the anchor than the positive.
This is ignored for `CosineSimilarityLoss`, `BatchHardSoftMarginTripletLoss` and `SupConLoss`.
"""

def __init__(
Expand All @@ -80,6 +87,8 @@ def __init__(
column_mapping: Dict[str, str] = None,
use_amp: bool = False,
warmup_proportion: float = 0.1,
distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
margin: float = 0.25,
):
if (warmup_proportion < 0.0) or (warmup_proportion > 1.0):
raise ValueError(
Expand All @@ -98,6 +107,8 @@ def __init__(
self.column_mapping = column_mapping
self.use_amp = use_amp
self.warmup_proportion = warmup_proportion
self.distance_metric = distance_metric
self.margin = margin

if model is None:
if model_init is not None:
Expand Down Expand Up @@ -313,16 +324,15 @@ def train(
if self.loss_class is losses.BatchHardSoftMarginTripletLoss:
train_loss = self.loss_class(
model=self.model.model_body,
distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
distance_metric=self.distance_metric,
)
elif self.loss_class is SupConLoss:
train_loss = self.loss_class(model=self.model.model_body)
else:

train_loss = self.loss_class(
model=self.model.model_body,
distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
margin=0.25,
distance_metric=self.distance_metric,
margin=self.margin,
)

train_steps = len(train_dataloader) * self.num_epochs
Expand Down

0 comments on commit 641fbfc

Please sign in to comment.