Add NLP Components to Benchmarking #213
Changes from 4 commits
New file (hunk `@@ -0,0 +1,43 @@`):

```python
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor
from transformers import AutoModelForSequenceClassification

from renate.models import RenateModule


class HuggingFaceSequenceClassificationTransformer(RenateModule):
    """RenateModule which wraps around Hugging Face transformers.

    Args:
        pretrained_model_name: Hugging Face model id.
        num_outputs: Number of outputs.
        loss_fn: The loss function to be optimized during the training.
    """

    def __init__(
        self,
        pretrained_model_name: str,
        num_outputs: int,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
    ) -> None:
        super().__init__(
            constructor_arguments={
                "pretrained_model_name": pretrained_model_name,
                "num_outputs": num_outputs,
            },
            loss_fn=loss_fn,
        )
        self._model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name, num_labels=num_outputs, return_dict=False
        )

    def forward(self, x: Dict[str, Tensor], task_id: Optional[str] = None) -> torch.Tensor:
        return self._model(**x)[0]

    def _add_task_params(self, task_id: str) -> None:
        assert not len(self._tasks_params_ids), "Transformer does not work for multiple tasks."
```

Review thread on the class definition:

> **Reviewer:** Is that necessary, or could we just use a `RenateWrapper`?
>
> **Author:** The `RenateWrapper` does not allow saving any additional constructor arguments, as is needed in this case.
>
> **Reviewer:** Makes sense.
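The `constructor_arguments` dict passed to `super().__init__` is what lets the module be re-instantiated from a checkpoint without the caller repeating its arguments. A minimal, framework-free sketch of that pattern (not Renate's actual implementation; `SerializableModule` and `ToyClassifier` are illustrative names):

```python
class SerializableModule:
    """Toy base class: remembers constructor arguments for re-instantiation."""

    def __init__(self, constructor_arguments):
        self._constructor_arguments = constructor_arguments

    def state(self):
        # A real implementation would also include the model weights.
        return {"cls": type(self).__name__, "args": self._constructor_arguments}


class ToyClassifier(SerializableModule):
    def __init__(self, pretrained_model_name, num_outputs):
        # Record everything needed to rebuild this object later.
        super().__init__(
            {"pretrained_model_name": pretrained_model_name, "num_outputs": num_outputs}
        )
        self.pretrained_model_name = pretrained_model_name
        self.num_outputs = num_outputs


model = ToyClassifier("distilbert-base-uncased", num_outputs=2)
checkpoint = model.state()
# Rebuild the model from the saved constructor arguments alone.
restored = ToyClassifier(**checkpoint["args"])
```

This also makes the reviewer exchange above concrete: a generic wrapper that does not record such arguments cannot restore a model whose construction depends on them.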
Changed file (hunk `@@ -67,4 +67,4 @@`):

```diff
 # Noticed different accuracy scores across Mac and GitHub Actions Workflows (which run on Linux)
 # TODO see if we can align the Mac and Linux results
-assert pytest.approx(test_config["expected_accuracy"]) == accuracies
+assert pytest.approx(test_config["expected_accuracy"]) == accuracies, accuracies
```

Review thread on the changed assertion:

> **Reviewer:** I don't understand this syntax?
>
> **Author:** If the assertion fails, it prints the value of `accuracies`.
Review thread on path handling:

> **Reviewer:** If we're disallowing path objects here, should we also do that in the other examples?
>
> **Author:** The use of complex types such as `Path` is no longer supported for our `*_fn` functions since 0.2.0. I can check for more occurrences, but I should have removed most of them before.
>
> **Reviewer:** I see occurrences of `Path` in most example config files. If you don't want to remove them now, we should create an issue to do it later.
>
> **Author:** I've updated the remaining config files as well as the documentation.
>
> **Reviewer:** Nice, thanks!
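The fix discussed above amounts to keeping `Path` out of the public `*_fn` signatures while still using it internally. A hedged sketch of what an updated config entry point could look like; the function name `data_module_fn`, its parameters, and the return value are illustrative, not Renate's exact API:

```python
from pathlib import Path


def data_module_fn(data_path: str, seed: int = 0):
    """Illustrative config entry point that accepts a plain string path.

    Only the public signature avoids complex types; the string can still
    be converted to a Path internally for convenient manipulation.
    """
    root = Path(data_path)  # conversion happens inside, not in the signature
    return {"root": str(root), "seed": seed}


result = data_module_fn("./data/benchmark")
```

Keeping signatures to simple types (str, int, float, bool) makes the `*_fn` arguments straightforward to serialize and to pass through from command-line or config-file values.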