From e86f08a8080511e4ee06c48fdbe654528729826b Mon Sep 17 00:00:00 2001 From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com> Date: Wed, 19 Jun 2024 15:13:41 +0200 Subject: [PATCH] Mistral nits (#57) * fix(mistral): add missing sentencepiece dependency It is required for Mistral models. * feat(mistral): added bfloat16 dtyle by default * chore(models): add warning about implicit conversion to bfloat16 --- optimum/tpu/modeling_mistral.py | 9 +++++++++ pyproject.toml | 3 ++- text-generation-inference/server/pyproject.toml | 1 + 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/optimum/tpu/modeling_mistral.py b/optimum/tpu/modeling_mistral.py index 1e8a79f1..f214f141 100644 --- a/optimum/tpu/modeling_mistral.py +++ b/optimum/tpu/modeling_mistral.py @@ -1630,3 +1630,12 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Unless specified otherwise, the model weights type will be bfloat16 + if "torch_dtype" not in kwargs: + logger.warning_once("Defaulting to `torch_dtype=torch.bfloat16` for this model") + torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) + # forward to base implementation + return super().from_pretrained(pretrained_model_name_or_path, *model_args, torch_dtype=torch_dtype, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 1da6a85b..01968a49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,8 @@ dependencies = [ "transformers == 4.41.1", "torch >= 2.3.0, <= 2.4.0", "torch-xla[tpu] >= 2.3.0, <= 2.4.0", - "loguru == 0.6.0" + "loguru == 0.6.0", + "sentencepiece == 0.2.0", ] [tool.setuptools_scm] diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml index b15263de..40985701 100644 --- a/text-generation-inference/server/pyproject.toml +++ b/text-generation-inference/server/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ 'safetensors == 0.4.2', 'transformers == 4.41.1', 'loguru == 0.6.0', + "sentencepiece == 0.2.0", ] [tool.setuptools]