diff --git a/bert_score/utils.py b/bert_score/utils.py index 649ebeb..3ec0a7c 100644 --- a/bert_score/utils.py +++ b/bert_score/utils.py @@ -109,7 +109,8 @@ def sent_encode(tokenizer, sent): if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): return tokenizer.encode( - sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True + sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.model_max_length, + truncation=True ) else: return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len) @@ -117,7 +118,8 @@ def sent_encode(tokenizer, sent): import transformers if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): - return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, truncation=True) + return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.model_max_length, + truncation=True) else: return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len)