From 0d32dd1e05d8d55e2976bcad3238e6817e75f3c2 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Fri, 24 Nov 2023 13:00:37 +0100
Subject: [PATCH 1/2] Implement 'batch_size' on model.predict

---
 docs/source/en/how_to/batch_sizes.mdx | 21 +++++++++++++++++++
 src/setfit/modeling.py                | 34 +++++++++++++++++++++++-------
 2 files changed, 48 insertions(+), 7 deletions(-)
 create mode 100644 docs/source/en/how_to/batch_sizes.mdx

diff --git a/docs/source/en/how_to/batch_sizes.mdx b/docs/source/en/how_to/batch_sizes.mdx
new file mode 100644
index 00000000..781e7a98
--- /dev/null
+++ b/docs/source/en/how_to/batch_sizes.mdx
@@ -0,0 +1,21 @@
+
+# Batch sizes
+In this how-to guide we will explore the effects of increasing the batch size in [`SetFitModel.predict`].
+
+## What are they?
+When processing on a GPU, often not all data fits in the GPU's VRAM at once. As a result, the data is split into **batches** of some pre-determined size. This is done both during training and during inference. In both scenarios, increasing the batch size often has notable consequences for processing efficiency and VRAM usage, as transferring data to and from the GPU can be relatively slow.
+
+For inference, it is often recommended to set the batch size high for notably faster processing.
+
+## In SetFit
+The batch size for inference in SetFit defaults to 32, but it can be overridden by passing a `batch_size` argument to [`SetFitModel.predict`], e.g. `model.predict(sentences, batch_size=128)`. On an RTX 3090 with a SetFit model based on the [paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2) Sentence Transformer, the following throughputs are reached:
+
+![setfit_speed_per_batch_size](https://github.com/huggingface/setfit/assets/37621491/c01d391b-aeba-4a4b-83f8-b09970a0d6e6)
+
+<Tip>
+
+Each sentence consists of 11 words in this experiment.
+
+</Tip>
+
+The default batch size of 32 does not result in the highest possible throughput on this hardware. Consider experimenting with this value to find the best throughput for your hardware.
\ No newline at end of file
diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py
index 793b2c72..29ba9e61 100644
--- a/src/setfit/modeling.py
+++ b/src/setfit/modeling.py
@@ -432,11 +432,20 @@ def _freeze_or_not(self, model: nn.Module, to_freeze: bool) -> None:
         for param in model.parameters():
             param.requires_grad = not to_freeze
 
-    def encode(self, inputs: List[str], show_progress_bar: Optional[bool] = None) -> Union[torch.Tensor, np.ndarray]:
+    def encode(
+        self, inputs: List[str], batch_size: int = 32, show_progress_bar: Optional[bool] = None
+    ) -> Union[torch.Tensor, np.ndarray]:
         """Convert input sentences to embeddings using the `SentenceTransformer` body.
 
         Args:
             inputs (`List[str]`): The input sentences to embed.
+            batch_size (`int`, defaults to `32`): The batch size to use when encoding the sentences to embeddings.
+                A higher value often means faster processing, at the cost of higher memory usage.
             show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
 
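+        Example:
+            >>> model = SetFitModel.from_pretrained(...)
+            >>> # batch_size=64 and the sentences are illustrative; tune the value to your hardware
+            >>> embeddings = model.encode(["sentence one", "sentence two"], batch_size=64)
+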
         Returns:
@@ -445,6 +449,7 @@ def encode(self, inputs: List[str], show_progress_bar: Optional[bool] = None) ->
         """
         return self.model_body.encode(
             inputs,
+            batch_size=batch_size,
             normalize_embeddings=self.normalize_embeddings,
             convert_to_tensor=self.has_differentiable_head,
             show_progress_bar=show_progress_bar,
@@ -472,12 +477,14 @@ def _output_type_conversion(
         return outputs
 
     def predict(
-        self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
+        self, inputs: List[str], batch_size: int = 32, as_numpy: bool = False, show_progress_bar: Optional[bool] = None
     ) -> Union[torch.Tensor, np.ndarray]:
         """Predict the various classes.
 
         Args:
             inputs (`List[str]`): The input sentences to predict classes for.
+            batch_size (`int`, defaults to `32`): The batch size to use when encoding the sentences to embeddings.
+                A higher value often means faster processing, at the cost of higher memory usage.
             as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead.
             show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
 
@@ -490,17 +497,19 @@ def predict(
             `Union[torch.Tensor, np.ndarray]`: A vector with equal length to the inputs, denoting to which
                 class each input is predicted to belong.
         """
-        embeddings = self.encode(inputs, show_progress_bar=show_progress_bar)
+        embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
         outputs = self.model_head.predict(embeddings)
         return self._output_type_conversion(outputs, as_numpy=as_numpy)
 
     def predict_proba(
-        self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
+        self, inputs: List[str], batch_size: int = 32, as_numpy: bool = False, show_progress_bar: Optional[bool] = None
     ) -> Union[torch.Tensor, np.ndarray]:
         """Predict the probabilities of the various classes.
 
         Args:
             inputs (`List[str]`): The input sentences to predict class probabilities for.
+            batch_size (`int`, defaults to `32`): The batch size to use when encoding the sentences to embeddings.
+                A higher value often means faster processing, at the cost of higher memory usage.
             as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead.
             show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
 
@@ -515,7 +524,7 @@ def predict_proba(
             `Union[torch.Tensor, np.ndarray]`: A matrix with shape [INPUT_LENGTH, NUM_CLASSES] denoting probabilities of
                 predicting an input as a class.
         """
-        embeddings = self.encode(inputs, show_progress_bar=show_progress_bar)
+        embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
         outputs = self.model_head.predict_proba(embeddings)
         return self._output_type_conversion(outputs, as_numpy=as_numpy)
 
@@ -574,11 +583,17 @@ def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model
         with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
             f.write(model_card_content)
 
-    def __call__(self, inputs: List[str]) -> torch.Tensor:
+    def __call__(
+        self, inputs: List[str], batch_size: int = 32, as_numpy: bool = False, show_progress_bar: Optional[bool] = None
+    ) -> Union[torch.Tensor, np.ndarray]:
         """Predict the various classes.
 
         Args:
             inputs (`List[str]`): The input sentences to predict classes for.
+            batch_size (`int`, defaults to `32`): The batch size to use when encoding the sentences to embeddings.
+                A higher value often means faster processing, at the cost of higher memory usage.
+            as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead.
+ show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding. Example: >>> model = SetFitModel.from_pretrained(...) @@ -589,7 +604,7 @@ def __call__(self, inputs: List[str]) -> torch.Tensor: `torch.Tensor`: A vector with equal length to the inputs, denoting to which class each input is predicted to belong. """ - return self.predict(inputs) + return self.predict(inputs, batch_size=batch_size, as_numpy=as_numpy, show_progress_bar=show_progress_bar) def _save_pretrained(self, save_directory: Union[Path, str]) -> None: save_directory = str(save_directory) From 392cf0dd3be2842499f87c7962eb1d78b02a9134 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 24 Nov 2023 13:08:01 +0100 Subject: [PATCH 2/2] Add batch sizes to toctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index af85e43b..5f0367c6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -29,6 +29,8 @@ title: Hyperparameter Optimization - local: how_to/knowledge_distillation title: Knowledge Distillation + - local: how_to/batch_sizes + title: Batch Sizes - local: how_to/absa title: Aspect Based Sentiment Analysis - local: how_to/v1.0.0_migration_guide
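To illustrate the API these patches add, a minimal usage sketch follows. The checkpoint name and sentences are placeholders, and `batch_size=128` is just one value worth trying; only `predict`, `predict_proba`, and their `batch_size`/`as_numpy` parameters come from the patch itself:

```python
from setfit import SetFitModel

# Placeholder checkpoint; any trained SetFit model behaves the same way.
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")

sentences = ["I loved the new Spider-Man movie!", "Pineapple on pizza is the worst."]

# Sentences are embedded in chunks of `batch_size`; larger batches typically
# raise throughput until VRAM becomes the bottleneck.
predictions = model.predict(sentences, batch_size=128)
probabilities = model.predict_proba(sentences, batch_size=128, as_numpy=True)
```

As the benchmark in the new how-to guide suggests, the default of 32 is not necessarily optimal, so it is worth sweeping a few batch sizes on your own hardware.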