remove enforcement of non special when adding tokens (#1521)
* remove enforcement of non special when adding tokens

* mut no longer needed

* add a small test

* nit

* style

* audit

* ignore cargo audit's own vulnerability

* update

* revert

* remove CVE
ArthurZucker authored Apr 30, 2024
1 parent 71c2a8d commit f2ec3b2
Showing 4 changed files with 19 additions and 2 deletions.
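
For context, the line removed from bindings/python/src/tokenizer.rs below forced special = false on every AddedToken passed to add_tokens, so a token declared special was silently downgraded. The following is a minimal sketch of the behavior after this change, assuming the same setup as the new test added below (an empty BPE model, with token ids assigned in insertion order):

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

# Empty BPE model; added tokens receive ids 0, 1, 2, 3 in insertion order.
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens([AddedToken("my", special=True), AddedToken("name", special=False), "is", "john"])

# The special=True flag now survives add_tokens, so "my" is dropped
# when decoding with skip_special_tokens=True.
print(tokenizer.decode([0, 1, 2, 3], skip_special_tokens=False))  # my name is john
print(tokenizer.decode([0, 1, 2, 3], skip_special_tokens=True))   # name is john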
3 changes: 3 additions & 0 deletions .github/workflows/python.yml
@@ -95,6 +95,9 @@ jobs:
           command: clippy
           args: --manifest-path ./bindings/python/Cargo.toml --all-targets --all-features -- -D warnings
 
+      - name: Install cargo-audit
+        run: cargo install cargo-audit
+
       - name: Run Audit
         uses: actions-rs/cargo@v1
         with:
3 changes: 3 additions & 0 deletions .github/workflows/rust.yml
@@ -81,6 +81,9 @@ jobs:
           command: test
           args: --verbose --manifest-path ./tokenizers/Cargo.toml --doc
 
+      - name: Install cargo-audit
+        run: cargo install cargo-audit
+
       - name: Run Audit
         uses: actions-rs/cargo@v1
         with:
3 changes: 1 addition & 2 deletions bindings/python/src/tokenizer.rs
@@ -1151,8 +1151,7 @@ impl PyTokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(PyAddedToken::from(content, Some(false)).get_token())
-                } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                    token.special = false;
+                } else if let Ok(token) = token.extract::<PyRefMut<PyAddedToken>>() {
                     Ok(token.get_token())
                 } else {
                     Err(exceptions::PyTypeError::new_err(
12 changes: 12 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -535,3 +535,15 @@ def test_splitting(self):
"▁▁▁▁▁▁",
"▁.",
]

def test_decode_special(self):
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens([AddedToken("my", special=True), AddedToken("name", special=False), "is", "john", "pair"])

# Can decode single sequences
output = tokenizer.decode([0, 1, 2, 3], skip_special_tokens=False)
assert output == "my name is john"

output = tokenizer.decode([0, 1, 2, 3], skip_special_tokens=True)
assert output == "name is john"
assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)
