-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtokenizer.go
126 lines (105 loc) · 2.85 KB
/
tokenizer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright (c) seasonjs. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
package rwkv
import (
"embed"
"encoding/json"
"fmt"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
)
// TokenizerType selects which tokenizer implementation a caller wants.
type TokenizerType uint8

const (
	// Normal selects the JSON-config-driven tokenizer (see NormalTokenizer).
	Normal TokenizerType = iota
	// World presumably selects the RWKV "world" tokenizer variant —
	// its implementation is not visible in this file; confirm elsewhere.
	World
)
// Tokenizer converts between text and token ids.
type Tokenizer interface {
	// Encode turns a string into a sequence of token ids.
	Encode(in string) ([]int, error)
	// Decode turns a sequence of token ids back into a string.
	Decode(in []int) string
}
// tokenizerFS embeds the 20B_tokenizer.json config consumed by
// NewNormalTokenizer. The comment below is a compiler directive and must
// stay immediately above the var declaration.
//
//go:embed 20B_tokenizer.json
var tokenizerFS embed.FS
// NormalTokenizer implements Tokenizer on top of a sugarme/tokenizer
// instance built from the embedded 20B_tokenizer.json config.
type NormalTokenizer struct {
	// tk is the fully wired tokenizer (model, normalizer, pre/post
	// processing, decoder, padding/truncation) built by NewNormalTokenizer.
	tk *tokenizer.Tokenizer
}
// NewNormalTokenizer builds a NormalTokenizer from the embedded
// 20B_tokenizer.json config. It decodes the config and then wires each
// component in order: model, normalizer, pre-tokenizer, post-processor,
// decoder, added vocabulary, truncation params, and padding params.
// It returns a non-nil error if decoding or any component creation fails.
func NewNormalTokenizer() (*NormalTokenizer, error) {
	f, err := tokenizerFS.Open("20B_tokenizer.json")
	if err != nil {
		return nil, err
	}
	// The embedded file was previously never closed; release it once
	// decoding is done.
	defer f.Close()

	var config *tokenizer.Config
	if err := json.NewDecoder(f).Decode(&config); err != nil {
		return nil, err
	}

	// 1. Model
	model, err := pretrained.CreateModel(config)
	if err != nil {
		return nil, fmt.Errorf("creating Model failed: %w", err)
	}
	tk := tokenizer.NewTokenizer(model)

	// 2. Normalizer
	n, err := pretrained.CreateNormalizer(config.Normalizer)
	if err != nil {
		return nil, fmt.Errorf("creating Normalizer failed: %w", err)
	}
	tk.WithNormalizer(n)

	// 3. PreTokenizer
	preTok, err := pretrained.CreatePreTokenizer(config.PreTokenizer)
	if err != nil {
		return nil, fmt.Errorf("creating PreTokenizer failed: %w", err)
	}
	tk.WithPreTokenizer(preTok)

	// 4. PostProcessor
	postProcessor, err := pretrained.CreatePostProcessor(config.PostProcessor)
	if err != nil {
		return nil, fmt.Errorf("creating PostProcessor failed: %w", err)
	}
	tk.WithPostProcessor(postProcessor)

	// 5. Decoder
	decoder, err := pretrained.CreateDecoder(config.Decoder)
	if err != nil {
		return nil, fmt.Errorf("creating Decoder failed: %w", err)
	}
	tk.WithDecoder(decoder)

	// 6. AddedVocabulary — special tokens and plain added tokens are
	// registered separately, and only when present.
	specialAddedTokens, addedTokens := pretrained.CreateAddedTokens(config.AddedTokens)
	if len(specialAddedTokens) > 0 {
		tk.AddSpecialTokens(specialAddedTokens)
	}
	if len(addedTokens) > 0 {
		tk.AddTokens(addedTokens)
	}

	// 7. TruncationParams
	truncParams, err := pretrained.CreateTruncationParams(config.Truncation)
	if err != nil {
		return nil, fmt.Errorf("creating TruncationParams failed: %w", err)
	}
	tk.WithTruncation(truncParams)

	// 8. PaddingParams
	paddingParams, err := pretrained.CreatePaddingParams(config.Padding)
	if err != nil {
		return nil, fmt.Errorf("creating PaddingParams failed: %w", err)
	}
	tk.WithPadding(paddingParams)

	return &NormalTokenizer{tk: tk}, nil
}
// Encode tokenizes input (without adding special tokens) and returns
// the resulting token ids, or an error from the underlying tokenizer.
func (t *NormalTokenizer) Encode(input string) ([]int, error) {
	seq := tokenizer.NewInputSequence(input)
	enc, err := t.tk.Encode(tokenizer.NewSingleEncodeInput(seq), false)
	if err != nil {
		return nil, err
	}
	return enc.Ids, nil
}
// Decode converts token ids back into a string (without skipping
// special tokens) using the underlying tokenizer.
func (t *NormalTokenizer) Decode(ids []int) string {
	return t.tk.Decode(ids, false)
}