Experimental: tokenizers with and without templates #168

Open · wants to merge 5 commits into main
18 changes: 13 additions & 5 deletions Package.swift
@@ -7,23 +7,31 @@ let package = Package(
name: "swift-transformers",
platforms: [.iOS(.v16), .macOS(.v13)],
products: [
.library(name: "Hub", targets: ["Hub"]),
// ^ Hub client library
.library(name: "Tokenizers", targets: ["Tokenizers"]),
// ^ Tokenizers. Includes `Hub` to download config files
.library(name: "TokenizersTemplates", targets: ["TokenizersTemplates"]),
// ^ Optionally depend on this to add chat template support to Tokenizers
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]),
// ^ Everything, including Core ML inference
.executable(name: "transformers", targets: ["TransformersCLI"]),
.executable(name: "hub-cli", targets: ["HubCLI"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0")
.package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0"),
],
targets: [
.executableTarget(
name: "TransformersCLI",
dependencies: [
"Models", "Generation", "Tokenizers",
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
dependencies: [ "Models", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "TokenizersCore", dependencies: ["Hub"], path: "Sources/Tokenizers"),
.target(name: "TokenizersTemplates", dependencies: ["TokenizersCore", .product(name: "Jinja", package: "Jinja")]),
.target(name: "Tokenizers", dependencies: ["TokenizersCore", .product(name: "Jinja", package: "Jinja")], path: "Sources/TokenizersWrapper"),
// ^ This is just a wrapper or façade against TokenizersCore, but adds templates if available
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
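For orientation, this is roughly how a downstream package could opt in to (or out of) chat-template support under this split. Sketch only: the repository URL, branch, and app/target names are placeholders; only the product names come from the manifest changes above.

```swift
// swift-tools-version:5.9
// Hypothetical client manifest (sketch): only the product names come from the diff above.
import PackageDescription

let package = Package(
    name: "MyApp",
    platforms: [.iOS(.v16), .macOS(.v13)],
    dependencies: [
        // URL and branch are placeholders.
        .package(url: "https://github.com/huggingface/swift-transformers", branch: "main"),
    ],
    targets: [
        .executableTarget(
            name: "MyApp",
            dependencies: [
                // Core tokenization, without chat-template support:
                .product(name: "Tokenizers", package: "swift-transformers"),
                // Uncomment to opt in to chat-template support:
                // .product(name: "TokenizersTemplates", package: "swift-transformers"),
            ]
        )
    ]
)
```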
213 changes: 36 additions & 177 deletions Sources/Tokenizers/Tokenizer.swift
@@ -7,9 +7,8 @@

import Hub
import Foundation
import Jinja

enum TokenizerError: Error {
public enum TokenizerError: Error {
case missingConfig
case missingTokenizerClassInConfig
case unsupportedTokenizer(String)
@@ -43,7 +42,7 @@ public protocol TokenizingModel {
}

// Helper - possibly to be moved somewhere else
func addedTokenAsString(_ addedToken: Config?) -> String? {
public func addedTokenAsString(_ addedToken: Config?) -> String? {
guard let addedToken = addedToken else { return nil }
if let stringValue = addedToken.stringValue {
return stringValue
@@ -161,6 +160,20 @@ public protocol Tokenizer {
) throws -> [Int]
}

extension Tokenizer {
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, truncation: false, maxLength: nil, tools: nil)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true, truncation: false, maxLength: nil, tools: nil)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true, truncation: false, maxLength: nil, tools: nil)
}
}

Member Author comment: See comment in TokenizersTemplates
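A quick usage sketch of the convenience overloads added above. Assumptions not in the diff: the module name `Tokenizers`, a tokenizer created elsewhere by a template-capable factory, and a chat template (or named templates such as `tool_use`) present in its config.

```swift
import Tokenizers  // module name assumed from the Package.swift changes above

// Sketch only: the tokenizer and its chat templates are assumed to exist.
func buildPrompt(with tokenizer: Tokenizer) throws -> [Int] {
    let messages = [["role": "user", "content": "What is the capital of France?"]]

    // Defaults from the extension above: addGenerationPrompt = true, no truncation, no tools.
    let ids = try tokenizer.applyChatTemplate(messages: messages)

    // Or pick a named template from the config, or pass a literal Jinja string:
    _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .name("tool_use"))
    _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: "{{ messages[0]['content'] }}")

    return ids
}
```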

public extension Tokenizer {
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
encode(text: text, addSpecialTokens: addSpecialTokens)
@@ -179,18 +192,8 @@ public extension Tokenizer {
}
}

let specialTokenAttributes: [String] = [
"bos_token",
"eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
"additional_special_tokens"
]

public class PreTrainedTokenizer: Tokenizer {
// open because it has to be subclassed from `TokenizersTemplates`
open class PreTrainedTokenizer: Tokenizer {
let model: TokenizingModel

public var bosToken: String? { model.bosToken }
@@ -201,17 +204,17 @@ public class PreTrainedTokenizer: Tokenizer {
public var unknownTokenId: Int? { model.unknownTokenId }
public var fuseUnknownTokens: Bool { model.fuseUnknownTokens }

private let addedTokens: Set<String>
private let specialTokens: [String: Int]
private let addedTokensRegex: NSRegularExpression?
let addedTokens: Set<String>
let specialTokens: [String: Int]
let addedTokensRegex: NSRegularExpression?

private let preTokenizer: PreTokenizer?
private let normalizer: Normalizer?
private let postProcessor: PostProcessor?
private let decoder: Decoder?
private let tokenizerConfig: Config
let preTokenizer: PreTokenizer?
let normalizer: Normalizer?
let postProcessor: PostProcessor?
let decoder: Decoder?
public let tokenizerConfig: Config

private let cleanUpTokenizationSpaces: Bool
let cleanUpTokenizationSpaces: Bool

required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String : Int] = [:]
@@ -359,19 +362,19 @@ public class PreTrainedTokenizer: Tokenizer {
model.convertIdToToken(id)
}

public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
open func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
open func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
open func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
}

public func applyChatTemplate(
open func applyChatTemplate(
messages: [[String: String]],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
@@ -385,130 +388,7 @@ public class PreTrainedTokenizer: Tokenizer {
/// Note: tool calling is not supported yet; it will be available in a future update.
tools: [[String: Any]]? = nil
) throws -> [Int] {
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
// Use chat template from argument
selectedChatTemplate = template
} else if let valueFromConfig = tokenizerConfig.chatTemplate {
if let arrayValue = valueFromConfig.arrayValue {
// If the config specifies a list of chat templates, convert them to a dictionary
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in
guard let name = item.name?.stringValue, let template = item.template?.stringValue else {
return nil
}
return (name, template)
})
if let chatTemplate, case .name(let name) = chatTemplate {
// Select chat template from config by name
if let matchingDictEntry = templateDict[name] {
selectedChatTemplate = matchingDictEntry
} else {
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config")
}
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
// Use tool use chat template from config
selectedChatTemplate = toolUseTemplate
} else if let defaultChatTemplate = templateDict["default"] {
// Use default chat template from config
selectedChatTemplate = defaultChatTemplate
}
} else if let stringValue = valueFromConfig.stringValue {
// Use chat template from config
selectedChatTemplate = stringValue
}
}

guard let selectedChatTemplate else {
throw TokenizerError.chatTemplate("No chat template was specified")
}

let template = try Template(selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
]

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String : Any] {
if specialTokenAttributes.contains(key), !(value is NSNull) {
if let stringValue = value as? String {
context[key] = stringValue
} else if let dictionary = value as? [NSString:Any] {
context[key] = addedTokenAsString(Config(dictionary))
} else {
context[key] = value
}
}
}

let rendered = try template.render(context)
var encodedTokens = encode(text: rendered, addSpecialTokens: false)
var maxLength = maxLength ?? encodedTokens.count
maxLength = min(maxLength, tokenizerConfig.modelMaxLength?.intValue ?? maxLength)
if encodedTokens.count > maxLength {
if truncation {
encodedTokens = Array(encodedTokens.prefix(maxLength))
}
}

return encodedTokens
}
}
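To keep the behaviour of the deleted block easy to reconstruct, here is the template-selection order it implemented, plus one hypothetical call; this summarises existing behaviour and adds nothing new. The named entry `tool_use` is only an illustration.

```swift
import Tokenizers  // module name assumed from the Package.swift changes above

// Selection order implemented by the removed code above (summary, not new behaviour):
// 1. A `.literal(...)` argument wins outright.
// 2. If the config's chat_template is a list of {name, template} entries, `.name("x")`
//    selects entry "x" (throws if absent); otherwise "tool_use" is chosen when tools
//    are passed, else "default".
// 3. If the config's chat_template is a plain string, that string is used.
// 4. Otherwise the call throws TokenizerError.chatTemplate("No chat template was specified").
func encodeToolPrompt(with tokenizer: Tokenizer, messages: [[String: String]]) throws -> [Int] {
    try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .name("tool_use"))
}
```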

// MARK: - Building

public struct AutoTokenizer {}

struct PreTrainedTokenizerClasses {
/// Class overrides for custom behaviour
/// Not to be confused with the TokenizerModel classes defined in TokenizerModel
static let tokenizerClasses: [String : PreTrainedTokenizer.Type] = [
"LlamaTokenizer": LlamaPreTrainedTokenizer.self
]
}

extension AutoTokenizer {
static func tokenizerClass(for tokenizerConfig: Config) -> PreTrainedTokenizer.Type {
guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else {
return PreTrainedTokenizer.self
}

// Some tokenizer_class entries use a Fast suffix
let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
if let tokenizerClass = PreTrainedTokenizerClasses.tokenizerClasses[tokenizerName] {
return tokenizerClass
}

return PreTrainedTokenizer.self
}

public static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
let tokenizerClass = tokenizerClass(for: tokenizerConfig)
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

public static func from(
pretrained model: String,
hubApi: HubApi = .shared
) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

public static func from(
modelFolder: URL,
hubApi: HubApi = .shared
) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
throw TokenizerError.chatTemplate("Not implemented; you may want to use the `TokenizersTemplates` target.")
}
}
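The switch to `open` and the widened member visibility above exist so that a template-aware subclass can live in the separate target. Below is a rough sketch of what such an override could look like, assuming the module names from Package.swift and that `ChatTemplateArgument` and `TokenizerError` are exported by the core module; the type name is hypothetical and template resolution is reduced to the literal case.

```swift
import TokenizersCore  // module name taken from the Package.swift changes; exact exports assumed
import Jinja

// Hypothetical subclass in the TokenizersTemplates target (simplified: literal templates only).
class TemplatePreTrainedTokenizer: PreTrainedTokenizer {
    override func applyChatTemplate(
        messages: [[String: String]],
        chatTemplate: ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [[String: Any]]?
    ) throws -> [Int] {
        guard case .literal(let templateString)? = chatTemplate else {
            throw TokenizerError.chatTemplate("No chat template was specified")
        }
        // Render the Jinja template, then tokenize the rendered text with the core encoder.
        let rendered = try Template(templateString).render([
            "messages": messages,
            "add_generation_prompt": addGenerationPrompt
        ])
        return encode(text: rendered, addSpecialTokens: false)
    }
}
```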

@@ -529,12 +409,13 @@ class T5Tokenizer : UnigramTokenizer {}

// MARK: - PreTrainedTokenizer classes

let sentencePieceUnderline = "▁"
// These need to be public to be visible from the wrapper factory

public let sentencePieceUnderline = "▁"

// Hack for Llama tokenizers, see https://github.com/huggingface/transformers/blob/bcb841f0073fcd7a4fb88ea8064313c17dcab04a/src/transformers/models/llama/tokenization_llama_fast.py#L181
// Return updated config, or nil
func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) throws -> Config? {

public func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) throws -> Config? {
// If it's already a Template processor (instead of a ByteLevel one), assume it's correct
let postProcessor = PostProcessorFactory.fromConfig(config: processorConfig)
guard !(postProcessor is TemplateProcessing) else { return nil }
@@ -573,25 +454,3 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?)
let postProcessorConfig = Config(["type": PostProcessorType.TemplateProcessing.rawValue, "single": single, "pair": pair])
return postProcessorConfig
}

// See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions
class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
let isLegacy: Bool

required init(tokenizerConfig: Config, tokenizerData: Config) throws {
isLegacy = tokenizerConfig.legacy?.boolValue ?? true
var configDictionary = tokenizerData.dictionary
if !isLegacy {
configDictionary.removeValue(forKey: "normalizer")
configDictionary["pre_tokenizer"] = ["type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, "prepend_scheme": "first"]
}

if let postProcessorConfig = try maybeUpdatePostProcessor(tokenizerConfig: tokenizerConfig, processorConfig: tokenizerData.postProcessor) {
configDictionary["post_processor"] = postProcessorConfig.dictionary
}

let updatedData = Config(configDictionary)
try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
}
}
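With the Llama-specific subclass removed from this module, the wrapper/templates side presumably has to reproduce the same post-processor fix-up using the now-public helpers. A minimal sketch, mirroring the deleted initializer; the function name and module import are hypothetical.

```swift
import Hub
import TokenizersCore  // assumed module name from the Package.swift changes above

// Sketch mirroring the deleted LlamaPreTrainedTokenizer initializer: splice the updated
// post-processor into the tokenizer data before constructing the core tokenizer.
func makeLlamaTokenizer(tokenizerConfig: Config, tokenizerData: Config) throws -> PreTrainedTokenizer {
    var dataDict = tokenizerData.dictionary
    if let updated = try maybeUpdatePostProcessor(
        tokenizerConfig: tokenizerConfig,
        processorConfig: tokenizerData.postProcessor
    ) {
        dataDict["post_processor"] = updated.dictionary
    }
    return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: Config(dataDict))
}
```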
