Bert fixes (#157)
* Lowercase normalization should have happened before

* If stripAccents is null, strip accents when lowercasing (see the sketch after this list)

See https://docs.rs/tokenizers/latest/src/tokenizers/normalizers/bert.rs.html#119-137

* stripAccents keeps the base character!

* Test case for stripAccents

* A couple of Bert tests with diacritics

* Map distilbert tokenizer

* BasicTokenizer lowercasing is now optional

And stripAccents is performed when lowercasing

* Fix decoder regexp

* Couple of no lowercase tests

* Format

* BertNormalizer tests: update for new defaults

* Additional tests

* Punctuation rules

* Edge cases for bert tokenizers

* Remove code copied by mistake
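In short, the rule now matches the Rust tokenizers crate: when stripAccents is unset, it follows the lowercase flag. The resolution as implemented in Normalizer.swift below:

    let shouldLowercase = config.lowercase?.boolValue ?? true
    let shouldStripAccents = config.stripAccents?.boolValue ?? shouldLowercase
    // stripAccents unset → follow lowercase; explicitly set → honored as-is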
pcuenca authored Jan 17, 2025
1 parent a867fea commit 313fbd7
Showing 9 changed files with 161 additions and 20 deletions.
50 changes: 42 additions & 8 deletions Sources/Tokenizers/BertTokenizer.swift
@@ -10,7 +10,7 @@ import Foundation
import Hub

public class BertTokenizer {
private let basicTokenizer = BasicTokenizer()
private let basicTokenizer: BasicTokenizer
private let wordpieceTokenizer: WordpieceTokenizer
private let maxLen = 512
private let tokenizeChineseChars: Bool
@@ -30,10 +30,12 @@ public class BertTokenizer {
tokenizeChineseChars: Bool = true,
bosToken: String? = nil,
eosToken: String? = nil,
fuseUnknownTokens: Bool = false
fuseUnknownTokens: Bool = false,
doLowerCase: Bool = true
) {
self.vocab = vocab
self.ids_to_tokens = Utils.invert(vocab)
self.basicTokenizer = BasicTokenizer(doLowerCase: doLowerCase)
self.wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab)
self.tokenizeChineseChars = tokenizeChineseChars
self.bosToken = bosToken
@@ -52,7 +54,8 @@ public class BertTokenizer {
let eosToken = tokenizerConfig.eosToken?.stringValue
let bosToken = tokenizerConfig.bosToken?.stringValue
let fuseUnknown = tokenizerConfig.fuseUnk?.boolValue ?? false
self.init(vocab: vocab, merges: merges, tokenizeChineseChars: tokenizeChineseChars, bosToken: bosToken, eosToken: eosToken, fuseUnknownTokens: fuseUnknown)
let doLowerCase = tokenizerConfig.doLowerCase?.boolValue ?? true
self.init(vocab: vocab, merges: merges, tokenizeChineseChars: tokenizeChineseChars, bosToken: bosToken, eosToken: eosToken, fuseUnknownTokens: fuseUnknown, doLowerCase: doLowerCase)
}
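For illustration only, a minimal sketch of constructing a cased tokenizer with the new parameter. The toy vocab is a hypothetical stand-in for a real vocab.txt, and this assumes the designated initializer's leading parameters are vocab: and merges: (with merges accepting nil), as the convenience call above suggests:

    let vocab = ["hello": 0, "world": 1, "[UNK]": 2]  // hypothetical toy vocabulary
    let cased = BertTokenizer(vocab: vocab, merges: nil, doLowerCase: false)  // cased: keep case and accents
    let uncased = BertTokenizer(vocab: vocab, merges: nil)  // doLowerCase defaults to true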


@@ -154,21 +157,36 @@ extension BertTokenizer: PreTrainedTokenizerModel {


class BasicTokenizer {
let doLowerCase: Bool

init(doLowerCase: Bool = true) {
self.doLowerCase = doLowerCase
}

let neverSplit = [
"[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"
]


func maybeStripAccents(_ text: String) -> String {
guard doLowerCase else { return text }
return text.folding(options: .diacriticInsensitive, locale: nil)
}

func maybeLowercase(_ text: String) -> String {
guard doLowerCase else { return text }
return text.lowercased()
}

func tokenize(text: String) -> [String] {
let splitTokens = text.folding(options: .diacriticInsensitive, locale: nil)
.components(separatedBy: NSCharacterSet.whitespaces)
let splitTokens = maybeStripAccents(text).components(separatedBy: NSCharacterSet.whitespaces)
let tokens = splitTokens.flatMap({ (token: String) -> [String] in
if neverSplit.contains(token) {
return [token]
}
var toks: [String] = []
var currentTok = ""
for c in token.lowercased() {
if c.isLetter || c.isNumber || c == "°" {
for c in maybeLowercase(token) {
if !c.isExtendedPunctuation {
currentTok += String(c)
} else if currentTok.count > 0 {
toks.append(currentTok)
@@ -187,6 +205,22 @@
}
}
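A quick sketch of the two paths, derived from the code above (not a test from this commit):

    BasicTokenizer(doLowerCase: true).tokenize(text: "Café, s'il")
    // → ["cafe", ",", "s", "'", "il"]  (accents folded, lowercased, punctuation split off)
    BasicTokenizer(doLowerCase: false).tokenize(text: "Café, s'il")
    // → ["Café", ",", "s", "'", "il"]  (case and accents preserved)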

extension Character {
/// https://github.com/huggingface/transformers/blob/8c1b5d37827a6691fef4b2d926f2d04fb6f5a9e3/src/transformers/tokenization_utils.py#L367
var isExtendedPunctuation: Bool {
if isPunctuation { return true }
if let value = unicodeScalars.first?.value {
switch value {
case 33...47: return true
case 58...64: return true
case 91...96: return true
case 123...126: return true
default: return false
}
}
return false
}
}
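A few spot checks of the predicate; the ASCII ranges cover characters such as $ that Unicode does not classify as punctuation:

    Character("$").isExtendedPunctuation  // true: 36 falls in 33...47
    Character(",").isExtendedPunctuation  // true: via isPunctuation
    Character("£").isExtendedPunctuation  // false: currency symbol, outside the ASCII ranges
    Character("a").isExtendedPunctuation  // false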

class WordpieceTokenizer {
let unkToken = "[UNK]"
2 changes: 1 addition & 1 deletion Sources/Tokenizers/Decoder.swift
@@ -58,7 +58,7 @@ class WordPieceDecoder: Decoder {
let cleanup: Bool

// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31
private let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'|n't|'m|'s|'ve|'re)", options: [])
private let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'\\s|n't|'m|'s|'ve|'re)", options: [])

required public init(config: Config) {
guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
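To see what the new '\s alternative buys, a minimal standalone check (not the decoder's full cleanup path):

    let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'\\s|n't|'m|'s|'ve|'re)")
    let s = "l ' eure"
    let match = re.firstMatch(in: s, range: NSRange(s.startIndex..., in: s))
    // match != nil: the quote is now captured together with its trailing space,
    // so decoding ["l", "'", "eu", "##re"] can collapse to "l'eure" (see BertSpacesTests)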
16 changes: 7 additions & 9 deletions Sources/Tokenizers/Normalizer.swift
@@ -146,14 +146,14 @@ class NFKCNormalizer: Normalizer {
class BertNormalizer: Normalizer {
let shouldCleanText: Bool
let shouldHandleChineseChars: Bool
let shouldStripAccents: Bool?
let shouldStripAccents: Bool
let shouldLowercase: Bool

required init(config: Config) {
self.shouldCleanText = config.cleanText?.boolValue ?? true
self.shouldHandleChineseChars = config.handleChineseChars?.boolValue ?? true
self.shouldStripAccents = config.stripAccents?.boolValue
self.shouldLowercase = config.lowercase?.boolValue ?? true
self.shouldStripAccents = config.stripAccents?.boolValue ?? shouldLowercase
}

func normalize(text: String) -> String {
@@ -164,7 +164,7 @@ class BertNormalizer: Normalizer {
if shouldHandleChineseChars {
output = handleChineseChars(text: output)
}
if shouldStripAccents ?? false {
if shouldStripAccents {
output = stripAccents(text: output)
}
if shouldLowercase {
@@ -219,12 +219,10 @@
}

private func stripAccents(text: String) -> String {
text.decomposedStringWithCanonicalMapping
.filter {
$0.unicodeScalars.allSatisfy { scalar in
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
}
}
// This might be the same as `text.folding(options: .diacriticInsensitive, locale: nil)`
String(text.decomposedStringWithCanonicalMapping.unicodeScalars.filter { scalar in
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
})
}
}
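With the new defaults, an empty config both lowercases and strips accents, and disabling lowercase also disables stripping unless stripAccents is set explicitly. A small sketch (outputs consistent with testBertNormalizerDefaults below):

    BertNormalizer(config: Config([:])).normalize(text: "Café")  // "cafe"
    BertNormalizer(config: Config(["lowercase": false])).normalize(text: "Café")  // "Café": stripAccents follows lowercase
    BertNormalizer(config: Config(["lowercase": false, "stripAccents": true])).normalize(text: "Café")  // "Cafe"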

5 changes: 4 additions & 1 deletion Sources/Tokenizers/Tokenizer.swift
@@ -63,6 +63,8 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {
struct TokenizerModel {
static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [
"BertTokenizer" : BertTokenizer.self,
"DistilbertTokenizer": BertTokenizer.self,
"DistilBertTokenizer": BertTokenizer.self,
"CodeGenTokenizer" : CodeGenTokenizer.self,
"CodeLlamaTokenizer" : CodeLlamaTokenizer.self,
"FalconTokenizer" : FalconTokenizer.self,
@@ -270,7 +272,8 @@ public class PreTrainedTokenizer: Tokenizer {
func cleanUp(text: String) -> String {
guard cleanUpTokenizationSpaces else { return text }

return text.replacingOccurrences(of: " .", with: ".")
return text
.replacingOccurrences(of: " .", with: ".")
.replacingOccurrences(of: " ?", with: "?")
.replacingOccurrences(of: " !", with: "!")
.replacingOccurrences(of: " ,", with: ",")
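Sketch of the restored chain's effect on a decoded string (a hypothetical standalone version of the same replacements):

    let raw = "hello , how are you ?"
    let cleaned = raw
        .replacingOccurrences(of: " ,", with: ",")
        .replacingOccurrences(of: " ?", with: "?")
    // cleaned == "hello, how are you?"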
37 changes: 37 additions & 0 deletions Tests/NormalizerTests/NormalizerTests.swift
@@ -120,6 +120,19 @@ class NormalizerTests: XCTestCase {
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFKCNormalizer)
}

func testStripAccents() {
let testCases: [(String, String)] = [
("département", "departement"),
]

//TODO: test combinations with/without lowercase
let config = Config(["stripAccents":true])
let normalizer = BertNormalizer(config: config)
for (arg, expect) in testCases {
XCTAssertEqual(normalizer.normalize(text: arg), expect)
}
}

func testBertNormalizer() {
let testCases: [(String, String)] = [
("Café", "café"),
@@ -133,6 +146,30 @@
("\u{00C5}", "\u{00E5}"),
]

for (arg, expect) in testCases {
let config = Config(["stripAccents":false])
let normalizer = BertNormalizer(config: config)
XCTAssertEqual(normalizer.normalize(text: arg), expect)
}

let config = Config(["type": NormalizerType.Bert.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? BertNormalizer)
}

func testBertNormalizerDefaults() {
// Python verification: t._tokenizer.normalizer.normalize_str("Café")
let testCases: [(String, String)] = [
("Café", "cafe"),
("François", "francois"),
("Ωmega", "ωmega"),
("über", "uber"),
("háček", "hacek"),
("Häagen\tDazs", "haagen dazs"),
("你好!", " 你 好 !"),
("𝔄𝔅ℭ⓵⓶⓷︷,︸,i⁹,i₉,㌀,¼", "𝔄𝔅ℭ⓵⓶⓷︷,︸,i⁹,i₉,㌀,¼"),
("Å", "a"),
]

for (arg, expect) in testCases {
let config = Config([:])
let normalizer = BertNormalizer(config: config)
1 change: 1 addition & 0 deletions Tests/TokenizersTests/Resources/bert_uncased_encoded.json
@@ -0,0 +1 @@
{"text": "Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer.", "bpe_tokens": ["fat", "##ou", "##ville", "-", "gr", "##est", "##ain", "est", "une", "commune", "du", "nord", "-", "ou", "##est", "du", "depart", "##ement", "de", "l", "'", "eu", "##re", "situ", "##ee", "au", "bo", "##rd", "de", "l", "'", "est", "##ua", "##ire", "de", "la", "seine", "et", "a", "pro", "##xi", "##mite", "du", "depart", "##ement", "du", "cal", "##va", "##dos", ".", "se", "##lon", "l", "'", "atlas", "des", "pays", "##ages", "de", "haute", "-", "norman", "##die", ",", "elle", "app", "##art", "##ient", "a", "la", "region", "nature", "##lle", "du", "lieu", "##vin", ".", "to", "##ute", "##fo", "##is", ",", "l", "'", "ag", "##rest", "##e", ",", "le", "service", "de", "la", "stat", "##ist", "##ique", "et", "de", "la", "prospective", "du", "minister", "##e", "de", "l", "'", "agriculture", ",", "de", "l", "'", "ag", "##ro", "##ali", "##ment", "##aire", "et", "de", "la", "fore", "##t", ",", "la", "class", "##e", "au", "se", "##in", "du", "pays", "d", "'", "aug", "##e", "(", "en", "tan", "##t", "que", "region", "ag", "##ric", "##ole", ")", ".", "la", "commune", "est", "a", "moi", "##ns", "de", "di", "##x", "kilometres", "a", "l", "'", "est", "de", "hon", "##fle", "##ur", ",", "a", "au", "##tan", "##t", "de", "be", "##uze", "##ville", "et", "a", "en", "##vir", "##on", "di", "##x", "-", "sept", "kilometres", "de", "pont", "-", "au", "##de", "##mer", "."], "token_ids": [101, 6638, 7140, 3077, 1011, 24665, 4355, 8113, 9765, 16655, 5715, 4241, 13926, 1011, 15068, 4355, 4241, 18280, 13665, 2139, 1048, 1005, 7327, 2890, 26179, 4402, 8740, 8945, 4103, 2139, 1048, 1005, 9765, 6692, 7442, 2139, 2474, 16470, 3802, 1037, 4013, 9048, 23419, 4241, 18280, 13665, 4241, 10250, 3567, 12269, 1012, 7367, 7811, 1048, 1005, 11568, 4078, 12778, 13923, 2139, 18535, 1011, 5879, 10265, 1010, 15317, 10439, 8445, 11638, 1037, 2474, 2555, 3267, 6216, 4241, 22470, 6371, 1012, 2000, 10421, 14876, 2483, 1010, 1048, 1005, 12943, 28533, 2063, 1010, 3393, 2326, 2139, 2474, 28093, 2923, 7413, 3802, 2139, 2474, 17464, 4241, 2704, 2063, 2139, 1048, 1005, 5237, 1010, 2139, 1048, 1005, 12943, 3217, 11475, 3672, 14737, 3802, 2139, 2474, 18921, 2102, 1010, 2474, 2465, 2063, 8740, 7367, 2378, 4241, 12778, 1040, 1005, 15476, 2063, 1006, 4372, 9092, 2102, 10861, 2555, 12943, 7277, 9890, 1007, 1012, 2474, 5715, 9765, 1037, 25175, 3619, 2139, 4487, 2595, 3717, 1037, 1048, 1005, 9765, 2139, 10189, 21031, 3126, 1010, 1037, 8740, 5794, 2102, 2139, 2022, 20395, 3077, 3802, 1037, 4372, 21663, 2239, 4487, 2595, 1011, 17419, 3717, 2139, 21179, 1011, 8740, 3207, 5017, 1012, 102], "decoded_text": "[CLS] fatouville - grestain est une commune du nord - ouest du departement de l'eure situee au bord de l'estuaire de la seine et a proximite du departement du calvados. selon l'atlas des paysages de haute - normandie, elle appartient a la region naturelle du lieuvin. 
toutefois, l'agreste, le service de la statistique et de la prospective du ministere de l'agriculture, de l'agroalimentaire et de la foret, la classe au sein du pays d'auge ( en tant que region agricole ). la commune est a moins de dix kilometres a l'est de honfleur, a autant de beuzeville et a environ dix - sept kilometres de pont - audemer. [SEP]"}
1 change: 1 addition & 0 deletions Tests/TokenizersTests/Resources/distilbert_cased_encoded.json
@@ -0,0 +1 @@
{"text": "Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer.", "bpe_tokens": ["Fat", "##ou", "##ville", "-", "G", "##resta", "##in", "est", "une", "commune", "du", "Nord", "-", "Ouest", "du", "d\u00e9partement", "de", "l", "'", "Eure", "situ\u00e9e", "au", "bord", "de", "l", "'", "est", "##uaire", "de", "la", "Seine", "et", "\u00e0", "proximit\u00e9", "du", "d\u00e9partement", "du", "Calvados", ".", "Selon", "l", "'", "atlas", "des", "paysage", "##s", "de", "Haute", "-", "Normandie", ",", "elle", "appartient", "\u00e0", "la", "r\u00e9gion", "naturelle", "du", "Lie", "##uv", "##in", ".", "Toutefois", ",", "l", "'", "A", "##gres", "##te", ",", "le", "service", "de", "la", "statistique", "et", "de", "la", "pro", "##spect", "##ive", "du", "minist\u00e8re", "de", "l", "'", "Agriculture", ",", "de", "l", "'", "A", "##gro", "##alim", "##entaire", "et", "de", "la", "For\u00eat", ",", "la", "classe", "au", "sein", "du", "pays", "d", "'", "Auge", "(", "en", "tant", "que", "r\u00e9gion", "agricole", ")", ".", "La", "commune", "est", "\u00e0", "moins", "de", "dix", "kilom\u00e8tres", "\u00e0", "l", "'", "est", "de", "Hon", "##f", "##leur", ",", "\u00e0", "autant", "de", "Be", "##uze", "##ville", "et", "\u00e0", "environ", "dix", "-", "sept", "kilom\u00e8tres", "de", "Pont", "-", "Aude", "##mer", "."], "token_ids": [101, 48803, 11010, 12043, 118, 144, 84038, 10245, 10176, 10231, 11380, 10168, 12004, 118, 21781, 10168, 16236, 10104, 180, 112, 35935, 15366, 10257, 27482, 10104, 180, 112, 10176, 54154, 10104, 10109, 13682, 10131, 254, 35483, 10168, 16236, 10168, 51934, 119, 20115, 180, 112, 92753, 10139, 93483, 10107, 10104, 17735, 118, 25771, 117, 11117, 52199, 254, 10109, 14387, 37232, 10168, 39710, 67000, 10245, 119, 46573, 117, 180, 112, 138, 68094, 10216, 117, 10141, 11989, 10104, 10109, 29303, 10131, 10104, 10109, 11284, 77229, 11942, 10168, 41853, 10104, 180, 112, 30954, 117, 10104, 180, 112, 138, 46692, 94974, 106895, 10131, 10104, 10109, 86549, 117, 10109, 15702, 10257, 11479, 10168, 13850, 172, 112, 72800, 113, 10110, 14222, 10121, 14387, 50350, 114, 119, 10159, 11380, 10176, 254, 14443, 10104, 23214, 22308, 254, 180, 112, 10176, 10104, 19431, 10575, 55692, 117, 254, 38585, 10104, 14321, 33302, 12043, 10131, 254, 16844, 23214, 118, 25097, 22308, 10104, 23986, 118, 55665, 12371, 119, 102], "decoded_text": "[CLS] Fatouville - Grestain est une commune du Nord - Ouest du d\u00e9partement de l'Eure situ\u00e9e au bord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages de Haute - Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service de la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, la classe au sein du pays d'Auge ( en tant que r\u00e9gion agricole ). 
La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 l'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix - sept kilom\u00e8tres de Pont - Audemer. [SEP]"}
2 changes: 1 addition & 1 deletion Tests/TokenizersTests/Resources/tokenizer_tests.json
100755 → 100644

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions Tests/TokenizersTests/TokenizerTests.swift
@@ -60,6 +60,18 @@ class T5TokenizerTests: TokenizerTests {
override class var unknownTokenId: Int? { 2 }
}

class BertCasedTokenizerTests: TokenizerTests {
override class var hubModelName: String? { "distilbert/distilbert-base-multilingual-cased" }
override class var encodedSamplesFilename: String? { "distilbert_cased_encoded" }
override class var unknownTokenId: Int? { 100 }
}

class BertUncasedTokenizerTests: TokenizerTests {
override class var hubModelName: String? { "google-bert/bert-base-uncased" }
override class var encodedSamplesFilename: String? { "bert_uncased_encoded" }
override class var unknownTokenId: Int? { 100 }
}

class GemmaTokenizerTests: TokenizerTests {
override class var hubModelName: String? { "pcuenq/gemma-tokenizer" }
override class var encodedSamplesFilename: String? { "gemma_encoded" }
@@ -108,6 +120,61 @@ class PhiSimpleTests: XCTestCase {
}
}

class BertDiacriticsTests: XCTestCase {
func testBertCased() async throws {
guard let tokenizer = try await AutoTokenizer.from(pretrained: "distilbert/distilbert-base-multilingual-cased") as? PreTrainedTokenizer else {
XCTFail()
return
}

XCTAssertEqual(tokenizer.encode(text: "mąka"), [101, 181, 102075, 10113, 102])
XCTAssertEqual(tokenizer.tokenize(text: "Car"), ["Car"])
}

func testBertCasedResaved() async throws {
guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/distilbert-base-multilingual-cased-tokenizer") as? PreTrainedTokenizer else {
XCTFail()
return
}

XCTAssertEqual(tokenizer.encode(text: "mąka"), [101, 181, 102075, 10113, 102])
}

func testBertUncased() async throws {
guard let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") as? PreTrainedTokenizer else {
XCTFail()
return
}

XCTAssertEqual(tokenizer.tokenize(text: "mąka"), ["ma", "##ka"])
XCTAssertEqual(tokenizer.encode(text: "mąka"), [101, 5003, 2912, 102])
XCTAssertEqual(tokenizer.tokenize(text: "département"), ["depart", "##ement"])
XCTAssertEqual(tokenizer.encode(text: "département"), [101, 18280, 13665, 102])
XCTAssertEqual(tokenizer.tokenize(text: "Car"), ["car"])

XCTAssertEqual(tokenizer.tokenize(text: "€4"), ["", "##4"])
XCTAssertEqual(tokenizer.tokenize(text: "test $1 R2 #3 €4 £5 ¥6 ₣7 ₹8 ₱9 test"), ["test", "$", "1", "r", "##2", "#", "3", "", "##4", "£5", "¥", "##6", "[UNK]", "", "##8", "", "##9", "test"])
}
}

class BertSpacesTests: XCTestCase {
func testEncodeDecode() async throws {
guard let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") as? PreTrainedTokenizer else {
XCTFail()
return
}

let text = "l'eure"
let tokenized = tokenizer.tokenize(text: text)
XCTAssertEqual(tokenized, ["l", "'", "eu", "##re"])
let encoded = tokenizer.encode(text: text)
XCTAssertEqual(encoded, [101, 1048, 1005, 7327, 2890, 102])
let decoded = tokenizer.decode(tokens: encoded, skipSpecialTokens: true)
// Note: this matches the behaviour of the Python "slow" tokenizer, but the fast one produces "l ' eure"
XCTAssertEqual(decoded, "l'eure")
}
}


struct EncodedTokenizerSamplesDataset: Decodable {
let text: String
