Feature/setfithead multi target #272
Conversation
See my review comments in huggingface#272 for details.
Thanks for these changes! I think we're getting really close. I made some comments based on some recent PRs that got merged since #212 was made. In particular, removing support for numpy arrays in the differentiable head, and removing `out_features=1`: 2 is now the minimum.
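For illustration, a minimal sketch of what that minimum looks like when constructing the head directly (the import path and `in_features` value are assumptions, not taken from this thread):

```python
from setfit.modeling import SetFitHead

# out_features=1 is no longer supported; 2 is now the minimum.
head = SetFitHead(in_features=768, out_features=2)
```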
I've made these changes and pushed them to this PR. Make sure to `git pull` them if you want to make more changes of your own. In short, all of the comments that I made with this review are now resolved (but you can still look at them if you want details on why I made some changes in 36f65bb).
As for your comments regarding `trainer.freeze()`, I'm not sure what caused the issue, but it seems to be gone after I made my changes.
I ran some experiments using multi-label classification with the different heads.

**Dataset generation script**

```python
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset

dataset = load_dataset("SetFit/hate_speech_offensive")

def to_multiclass(sample):
    """
    from
    (0: 'hate-speech', 1: 'offensive-language' or 2: 'neither')
    to
    ([1, 0]: 'hate-speech', [1, 1]: 'offensive-language' or [0, 0]: 'neither')
    """
    label = sample["label"]
    sample["label"] = [1 if label == 0 else 0, 1 if label == 1 else 0]
    return sample

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8).map(to_multiclass)
eval_dataset = dataset["test"].map(to_multiclass)
```

I want to point out that this isn't a very natural use of a multi-label dataset. That said, I couldn't find an actual multi-label dataset on the Hub.

**Training scripts**

**Logistic regression head**

```python
trainer = SetFitTrainer(
    model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
```

**Differentiable head**

```python
trainer = SetFitTrainer(
    model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train and evaluate
trainer.freeze()  # Freeze the head
trainer.train()  # Train only the body

# Unfreeze the head and freeze the body -> head-only training
trainer.unfreeze(keep_body_frozen=True)

trainer.train(
    num_epochs=25,  # The number of epochs to train the head or the whole model (body and head)
    batch_size=16,
    body_learning_rate=1e-5,  # The body's learning rate
    learning_rate=1e-2,  # The head's learning rate
    l2_weight=0.0,  # Weight decay on **both** the body and head. If `None`, will use 0.01.
)
```

And the model for testing was …
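The exact model isn't preserved in this thread; a minimal sketch of how a multi-target SetFit model with a differentiable head might be loaded (the checkpoint name and parameter values here are assumptions):

```python
from setfit import SetFitModel

# Assumed checkpoint; any sentence-transformers model would do.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},  # one output per target label
)
```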
**Results**

Notes:

To me, this is indicative that the multi-label differentiable head performs equivalently to the logistic regression head. In other words, this PR seems to be successful in adding multi-label classification support to `SetFitHead`! 🎉
I'm always thankful for your thoughtful comments and edits (also your comments on previous issues). I will check your edits!
@tomaarsen I think it would be better to submit the README edits in a separate PR.
I confirmed your changes, thank you! Also, I ran a similar multi-label experiment and got similar results, so I think you can merge this into main. Thank you!
If the changes that you have planned for the README relate to the changes from this PR, then I think they should be included in this PR. That way, the code and README get updated at the same time. I'm glad to hear that your experiments work too!
Got it! I edited the README to reflect our changes.
I'm satisfied with almost everything in this PR, with one exception. I'm not sure what the best course of action is, nor what the "normal" approach for this is. Perhaps to provide the `SetFitDataset` with a `label_postprocessing` function that either converts the labels to floats or to longs, depending on what is needed?
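A minimal sketch of what that could look like (the `label_postprocessing` parameter is hypothetical, proposed here rather than taken from the codebase):

```python
import torch

# Hypothetical hook: SetFitDataset would receive a callable that fixes
# the label dtype, instead of hard-coding the cast in collate_fn.
def long_labels(labels: torch.Tensor) -> torch.Tensor:
    return labels.long()   # single-label classification

def float_labels(labels: torch.Tensor) -> torch.Tensor:
    return labels.float()  # multi-target classification

# dataset = SetFitDataset(..., label_postprocessing=float_labels)
# ...and inside collate_fn, the hook would replace the hard-coded cast:
# labels = self.label_postprocessing(torch.Tensor(labels))
```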
I understand your concern and agree with you. I fixed some code to do what you want.
src/setfit/data.py (outdated)

```diff
@@ -277,6 +277,7 @@ def collate_fn(batch):
     # convert to tensors
     features = {k: torch.Tensor(v).int() for k, v in features.items()}
-    labels = torch.Tensor(labels).long()
+    labels = torch.Tensor(labels)
+    labels = labels.long() if isinstance(label, int) else labels.float()
```
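For context on why the dtype matters here: PyTorch's `CrossEntropyLoss` expects integer (`long`) class indices, while `BCEWithLogitsLoss`, typically used for multi-target classification, expects `float` targets with the same shape as the logits. A minimal illustration:

```python
import torch
from torch import nn

# Single-label: CrossEntropyLoss wants long class indices.
logits = torch.randn(4, 3)           # (batch, num_classes)
labels = torch.tensor([0, 2, 1, 0])  # dtype=torch.long
loss = nn.CrossEntropyLoss()(logits, labels)

# Multi-target: BCEWithLogitsLoss wants float 0/1 targets per class.
logits = torch.randn(4, 2)
targets = torch.tensor([[1., 0.], [1., 1.], [0., 0.], [1., 0.]])  # dtype=torch.float
loss = nn.BCEWithLogitsLoss()(logits, targets)
```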
My suggestion is to use the type of `label`. The type of `label` should be `int` for single-label classification, but `List` for multi-label classification. We could write

```python
labels = torch.Tensor(labels).long() if isinstance(label, int) else torch.Tensor(labels).float()
```

but I felt it was too long, so I fixed it as I pushed.
Or

```python
labels = labels.long() if len(labels.size()) == 1 else labels.float()
```

whichever you want!
I think the second solution is best!

```python
labels = labels.long() if len(labels.size()) == 1 else labels.float()
```

That should accurately measure whether we are in a multi-target situation, even if the user accidentally supplies floats instead of integers.
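A quick illustration of why the dimensionality check works (a sketch, assuming labels arrive as a list of ints or a list of lists):

```python
import torch

single = torch.Tensor([0, 2, 1])        # single-label -> 1-D tensor
multi = torch.Tensor([[1, 0], [1, 1]])  # multi-target -> 2-D tensor

print(len(single.size()))  # 1 -> cast with .long()
print(len(multi.size()))   # 2 -> cast with .float()
```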
Oh, I see. Your suggestion makes sense to me! I will fix that!
Looks good to me now! Thanks for making all of these changes! 🎉
I tried to resolve the merge conflicts and fix some errors. Please check that this PR is what you intended. Also, some of my tests failed and I cannot figure out the reason...