Commit

Merge pull request #353 from Zurnaz/llama_tpu_tokenizer_fix
fix: TPU tokenizer errors
henk717 authored May 8, 2023
2 parents cb4af7e + d53726b commit 0f91298
Showing 2 changed files with 3 additions and 3 deletions.
modeling/inference_models/hf.py (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ def decode_wrapper(self, token_ids, *args, **kwargs):
                 token_ids = [first]
             elif len(token_ids) > 0:
                 first = int(token_ids[0])
-            elif token_ids:
+            elif token_ids is not None and len(token_ids) > 0:
                 first = token_ids[0]
             result = original_decode(self, token_ids, *args, **kwargs)
             if first is not None and first in has_prefix_space:
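
The old elif token_ids: truthiness check is what breaks on the TPU backend if token_ids comes back as an array rather than a plain list: numpy (and jax) arrays refuse boolean coercion when they hold more than one element. A minimal repro of that failure mode, under the assumption that the ids arrive as a numpy array (the values here are made up):

import numpy as np

token_ids = np.array([5618, 318, 257])  # made-up token ids
first = None

try:
    if token_ids:  # the old check: `elif token_ids:`
        first = token_ids[0]
except ValueError as err:
    # numpy raises: "The truth value of an array with more than one
    # element is ambiguous. Use a.any() or a.all()"
    print(err)

# The replacement check works for lists and arrays alike:
if token_ids is not None and len(token_ids) > 0:
    first = token_ids[0]
print(first)  # 5618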
modeling/inference_models/hf_mtj.py (4 changes: 2 additions & 2 deletions)
@@ -17,6 +17,7 @@
     ModelCapabilities,
 )
 from modeling.inference_models.hf import HFInferenceModel
+from modeling.tokenizer import GenericTokenizer
 
 # This file shouldn't be imported unless using the TPU
 assert utils.koboldai_vars.use_colab_tpu
@@ -193,8 +194,7 @@ def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.modeldim = int(
             tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
         )
-
-        self.tokenizer = tpu_mtj_backend.tokenizer
+        self.tokenizer = GenericTokenizer(tpu_mtj_backend.tokenizer)
 
         if (
             utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
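
Wrapping the raw MTJ tokenizer in GenericTokenizer gives the decode path above a single interface whether the underlying tokenizer is a Hugging Face one or something sentencepiece-like from the TPU backend. A rough sketch of that wrapper pattern, assuming (not quoting) the real class in modeling/tokenizer.py:

class GenericTokenizer:
    # Assumed shape only; the real class in modeling/tokenizer.py may differ.
    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def __getattr__(self, name):
        # Anything not overridden falls through to the wrapped tokenizer.
        return getattr(self.tokenizer, name)

    def encode(self, text: str) -> list:
        # Normalize whatever the backend returns to a plain list of ids.
        return list(self.tokenizer.encode(text))

    def decode(self, tokens) -> str:
        # Accept a single id, a list, or an array-like of ids.
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.tokenizer.decode(list(tokens))

With a wrapper like this in place, callers such as decode_wrapper in hf.py no longer need to care which backend produced the tokenizer.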
