Skip to content

Commit

Permalink
No longer needlessly deepcopy the original model state (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Dec 12, 2022
1 parent acb99fd commit 46fcd9f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 2 additions & 1 deletion scripts/setfit/run_fewshot_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, args: argparse.Namespace) -> None:
self.args.model, max_seq_length=args.max_seq_length, add_normalization_layer=args.add_normalization_layer
)
self.model = self.model_wrapper.model
self.model_original_state = copy.deepcopy(self.model.state_dict())

def get_classifier(self, sbert_model: SentenceTransformer) -> SKLearnWrapper:
if self.args.classifier == "logistic_regression":
Expand All @@ -113,7 +114,7 @@ def get_classifier(self, sbert_model: SentenceTransformer) -> SKLearnWrapper:

def train(self, data: Dataset) -> SKLearnWrapper:
"Trains a SetFit model on the given few-shot training data."
self.model.load_state_dict(copy.deepcopy(self.model_wrapper.model_original_state))
self.model.load_state_dict(copy.deepcopy(self.model_original_state))

x_train = data["text"]
y_train = data.remove_columns("text").to_pandas().values
Expand Down
3 changes: 0 additions & 3 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import os
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -41,7 +40,6 @@
class SetFitBaseModel:
def __init__(self, model, max_seq_length: int, add_normalization_layer: bool) -> None:
self.model = SentenceTransformer(model)
self.model_original_state = copy.deepcopy(self.model.state_dict())
self.model.max_seq_length = max_seq_length

if add_normalization_layer:
Expand Down Expand Up @@ -208,7 +206,6 @@ def __init__(
self.multi_target_strategy = multi_target_strategy
self.l2_weight = l2_weight

self.model_original_state = copy.deepcopy(self.model_body.state_dict())
self.normalize_embeddings = normalize_embeddings

def fit(
Expand Down

0 comments on commit 46fcd9f

Please sign in to comment.