Skip to content

Commit

Permalink
Remove the check for start and stop tokens in the LLM bundler.
Browse files · Browse the repository at this point in the history
PiperOrigin-RevId: 697029482
Branch information:
hheydary authored and copybara-github committed Nov 16, 2024
1 parent 3eb8983 commit 0cebcc0
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 44 deletions.
13 changes: 0 additions & 13 deletions mediapipe/tasks/python/genai/bundler/llm_bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,6 @@ def _validate_config(config: BundleConfig):
"Please ensure you are passing a valid SentencePiece model."
) from e

encoded_start_token = sp.PieceToId(config.start_token)
if encoded_start_token == sp.unk_id():
raise ValueError(
f"Failed to encode start token {config.start_token} with tokenizer."
)

for stop_token in config.stop_tokens:
encoded_stop_token = sp.PieceToId(stop_token)
if encoded_stop_token == sp.unk_id():
raise ValueError(
f"Failed to encode stop token {stop_token} with tokenizer."
)


def create_bundle(config: BundleConfig):
"""Creates a bundle from the given config."""
Expand Down
31 changes: 0 additions & 31 deletions mediapipe/tasks/python/genai/bundler/llm_bundler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for llm_bundler."""

import os
import string
import zipfile
Expand Down Expand Up @@ -147,35 +145,6 @@ def test_invalid_stop_tokens_raises_value_error(self):
with self.assertRaisesRegex(ValueError, "stop_tokens must be non-empty"):
llm_bundler.create_bundle(config)

def test_invalid_start_stop_tokens_raises_value_error(self):
tempdir = self.create_tempdir()
sp_file_path = self._create_sp_model(tempdir.full_path)
tflite_file_path = self._create_tflite_model(tempdir.full_path)
output_file = os.path.join(tempdir, "test.task")
config = llm_bundler.BundleConfig(
tflite_model=tflite_file_path,
tokenizer_model=sp_file_path,
start_token="invalid_token",
stop_tokens=[self.EOS],
output_filename=output_file,
)
with self.assertRaisesRegex(
ValueError, "Failed to encode start token invalid_token with tokenizer"
):
llm_bundler.create_bundle(config)

config = llm_bundler.BundleConfig(
tflite_model=tflite_file_path,
tokenizer_model=sp_file_path,
start_token=self.BOS,
stop_tokens=["invalid_token"],
output_filename=output_file,
)
with self.assertRaisesRegex(
ValueError, "Failed to encode stop token invalid_token with tokenizer"
):
llm_bundler.create_bundle(config)

def test_invalid_tokenizer_model_raises_value_error(self):
tempdir = self.create_tempdir()
sp_file_path = self._create_sp_model(tempdir.full_path, corrupt=True)
Expand Down

0 comments on commit 0cebcc0

Please sign in to comment.