diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 5e96cd60..55354f31 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -304,7 +304,6 @@ def train( num_epochs = num_epochs or self.num_epochs batch_size = batch_size or self.batch_size learning_rate = learning_rate or self.learning_rate - batch_size = batch_size or self.batch_size is_differentiable_head = isinstance(self.model.model_head, torch.nn.Module) # If False, assume using sklearn if not is_differentiable_head or self._freeze: diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index bd45e1cb..3a2d3b55 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -150,7 +150,6 @@ def train( num_epochs = num_epochs or self.num_epochs batch_size = batch_size or self.batch_size learning_rate = learning_rate or self.learning_rate - batch_size = batch_size or self.batch_size is_differentiable_head = isinstance( self.student_model.model_head, torch.nn.Module ) # If False, assume using sklearn