From d9043717ee4c43b5034ee72d7f6e732b045b9e77 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 23 Nov 2022 14:51:17 +0100 Subject: [PATCH 1/2] No longer needlessly deepcopy the original model state --- src/setfit/modeling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 10f37cec..d8c8fb33 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -1,4 +1,3 @@ -import copy import os from dataclasses import dataclass from pathlib import Path @@ -41,7 +40,6 @@ class SetFitBaseModel: def __init__(self, model, max_seq_length: int, add_normalization_layer: bool) -> None: self.model = SentenceTransformer(model) - self.model_original_state = copy.deepcopy(self.model.state_dict()) self.model.max_seq_length = max_seq_length if add_normalization_layer: @@ -208,7 +206,6 @@ def __init__( self.multi_target_strategy = multi_target_strategy self.l2_weight = l2_weight - self.model_original_state = copy.deepcopy(self.model_body.state_dict()) self.normalize_embeddings = normalize_embeddings def fit( From eceea342a959520b23a75f4d3cd5e04be020bece Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 23 Nov 2022 15:36:48 +0100 Subject: [PATCH 2/2] Modify script to accompany removal of model_original_state --- scripts/setfit/run_fewshot_multilabel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/setfit/run_fewshot_multilabel.py b/scripts/setfit/run_fewshot_multilabel.py index 6438c923..d0e8ad90 100644 --- a/scripts/setfit/run_fewshot_multilabel.py +++ b/scripts/setfit/run_fewshot_multilabel.py @@ -100,6 +100,7 @@ def __init__(self, args: argparse.Namespace) -> None: self.args.model, max_seq_length=args.max_seq_length, add_normalization_layer=args.add_normalization_layer ) self.model = self.model_wrapper.model + self.model_original_state = copy.deepcopy(self.model.state_dict()) def get_classifier(self, sbert_model: SentenceTransformer) -> SKLearnWrapper: if self.args.classifier == "logistic_regression": @@ -113,7 +114,7 @@ def get_classifier(self, sbert_model: SentenceTransformer) -> SKLearnWrapper: def train(self, data: Dataset) -> SKLearnWrapper: "Trains a SetFit model on the given few-shot training data." - self.model.load_state_dict(copy.deepcopy(self.model_wrapper.model_original_state)) + self.model.load_state_dict(copy.deepcopy(self.model_original_state)) x_train = data["text"] y_train = data.remove_columns("text").to_pandas().values