Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more loss function options #159

Merged
merged 7 commits into from
Nov 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class SetFitTrainer:
warmup_proportion (`float`, *optional*, defaults to `0.1`):
Proportion of the warmup in the total training steps.
Must be greater than or equal to 0.0 and less than or equal to 1.0.
distance_metric (`Callable`, *optional*, defaults to `BatchHardTripletLossDistanceFunction.cosine_distance`):
Function that returns a distance between two embeddings.
It is used by the triplet losses and
is ignored for `CosineSimilarityLoss` and `SupConLoss`.
margin (`float`, *optional*, defaults to `0.25`): Margin for the triplet loss.
Negative samples should be at least `margin` further from the anchor than the positive sample.
This is ignored for `CosineSimilarityLoss`, `BatchHardSoftMarginTripletLoss` and `SupConLoss`.
"""

def __init__(
Expand All @@ -80,6 +87,8 @@ def __init__(
column_mapping: Dict[str, str] = None,
use_amp: bool = False,
warmup_proportion: float = 0.1,
distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
margin: float = 0.25,
):
if (warmup_proportion < 0.0) or (warmup_proportion > 1.0):
raise ValueError(
Expand All @@ -98,6 +107,8 @@ def __init__(
self.column_mapping = column_mapping
self.use_amp = use_amp
self.warmup_proportion = warmup_proportion
self.distance_metric = distance_metric
self.margin = margin

if model is None:
if model_init is not None:
Expand Down Expand Up @@ -313,16 +324,15 @@ def train(
if self.loss_class is losses.BatchHardSoftMarginTripletLoss:
train_loss = self.loss_class(
model=self.model.model_body,
distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
distance_metric=self.distance_metric,
)
elif self.loss_class is SupConLoss:
train_loss = self.loss_class(model=self.model.model_body)
else:

train_loss = self.loss_class(
model=self.model.model_body,
distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
margin=0.25,
distance_metric=self.distance_metric,
margin=self.margin,
)

train_steps = len(train_dataloader) * self.num_epochs
Expand Down