Add SetFitModel.to (#229) (#236)
* Add SetFitModel.to

* Add docstring for 'to' to clarify that we don't copy

Note that this means we differ from e.g. torch, where `Tensor.to` creates a moved copy rather than moving the tensor itself to a different device.
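A minimal sketch of that distinction, assuming a CUDA device is available (the checkpoint name simply mirrors the one used in the tests below):

import torch
from setfit import SetFitModel

t = torch.zeros(3)
t_gpu = t.to("cuda:0")  # torch.Tensor.to returns a copy placed on the GPU...
print(t.device)         # ...while `t` itself stays on the CPU: cpu

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
moved = model.to("cuda:0")  # SetFitModel.to moves the model in place...
print(moved is model)       # ...and returns the same object: True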

* Add tests for 'to'.

If CUDA is available, then try to move a model from the CPU to the GPU and back to the CPU.

* Reformat test_modeling

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>

Co-authored-by: Jegor Kitškerkin <jegor.kitskerkin@gmail.com>
tomaarsen and jegork authored Dec 15, 2022
1 parent 01ff536 commit 7e2ae63
Showing 2 changed files with 44 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/setfit/modeling.py
@@ -330,6 +330,22 @@ def predict_proba(self, x_test: List[str]) -> Union[torch.Tensor, np.ndarray]:
        embeddings = self.model_body.encode(x_test, normalize_embeddings=self.normalize_embeddings)
        return self.model_head.predict_proba(embeddings)

    def to(self, device: Union[str, torch.device]) -> "SetFitModel":
        """Move this SetFitModel to `device`, and then return `self`. This method does not copy.

        Args:
            device (Union[str, torch.device]): The identifier of the device to move the model to.

        Returns:
            SetFitModel: Returns the original model, but now on the desired device.
        """
        self.model_body = self.model_body.to(device)

        # Only move the head if it is a torch module; the default scikit-learn head stays on the CPU.
        if isinstance(self.model_head, torch.nn.Module):
            self.model_head = self.model_head.to(device)

        return self

    def __call__(self, inputs):
        return self.predict(inputs)
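
For context, a brief usage sketch of the new method. This is a hypothetical snippet, assuming a CUDA device is available; the checkpoint name mirrors the one used in the tests below.

from setfit import SetFitModel

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")

# Move the SentenceTransformer body (and the head, if it is a torch module) to the GPU.
model.to("cuda:0")
print(model.model_body.device)  # cuda:0

# Move back to the CPU, e.g. before saving or CPU-only inference.
model.to("cpu")
print(model.model_body.device)  # cpu

With the default scikit-learn LogisticRegression head, only the body is moved; the isinstance check above skips heads that are not torch modules.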

28 changes: 28 additions & 0 deletions tests/test_modeling.py
@@ -1,6 +1,7 @@
from unittest import TestCase

import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
@@ -209,3 +210,30 @@ def test_setfit_from_pretrained_local_model_with_head(tmp_path):
    model = SetFitModel.from_pretrained(str(tmp_path.absolute()))

    assert isinstance(model, SetFitModel)


def test_to_logistic_head():
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
    devices = (
        [torch.device("cpu"), torch.device("cuda", 0), torch.device("cpu")]
        if torch.cuda.is_available()
        else [torch.device("cpu")]
    )
    for device in devices:
        model.to(device)
        assert model.model_body.device == device


def test_to_torch_head():
    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
    )
    devices = (
        [torch.device("cpu"), torch.device("cuda", 0), torch.device("cpu")]
        if torch.cuda.is_available()
        else [torch.device("cpu")]
    )
    for device in devices:
        model.to(device)
        assert model.model_body.device == device
        assert model.model_head.device == device
