<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <title>RWKV ONNX demo</title>
    <script src="ort.min.js"></script>
  </head>
  <body>
    <script type="text/javascript">
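      // Minimal in-browser demo: run an RWKV language model exported to ONNX
      // with onnxruntime-web, feeding one token per step and carrying the
      // recurrent state tensors between runs.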
      // Greedy decoding: return the index of the largest logit.
      function greedy_sampling(x) {
        var max_k = 0
        var max_v = x[0]
        for (var i = 1; i < x.length; i++) { // x.length is the vocab size (50277)
          if (x[i] > max_v) {
            max_v = x[i]
            max_k = i
          }
        }
        return max_k
      }
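      // (A temperature or top-p sampler could be dropped in here instead;
      // greedy decoding keeps the demo deterministic.)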
      // use an async context to call onnxruntime functions.
      async function main() {
        try {
          ort.env.logLevel = "verbose"
          ort.env.logLevelInternal = "verbose"
          const sessionOption = {
            //executionProviders: ['webgl'],  // default backend is wasm
            graphOptimizationLevel: 'all'
          }
          // create a new session and load the specific model.
          const session = await ort.InferenceSession.create('./rwkv.onnx', sessionOption)
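          // Recurrent state buffers: 12 layers x 768 channels, matching the
          // exported model. xx/aa/bb/pp carry the time-mix (attention) state,
          // xx_ffn the channel-mix (FFN) state.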
          const idx_d = new Int32Array(1024)
          const xx_att_d = new Float32Array(12*768)
          const aa_att_d = new Float32Array(12*768)
          const bb_att_d = new Float32Array(12*768)
          const pp_att_d = new Float32Array(12*768)
          const xx_ffn_d = new Float32Array(12*768)
          pp_att_d.fill(-1e30)
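          // pp tracks the running maximum exponent of the numerically stable
          // WKV recurrence, so it starts at effectively -infinity.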
          const idx = new ort.Tensor('int32', idx_d, [1024])
          const xx_att = new ort.Tensor('float32', xx_att_d, [12, 768])
          const aa_att = new ort.Tensor('float32', aa_att_d, [12, 768])
          const bb_att = new ort.Tensor('float32', bb_att_d, [12, 768])
          const pp_att = new ort.Tensor('float32', pp_att_d, [12, 768])
          const xx_ffn = new ort.Tensor('float32', xx_ffn_d, [12, 768])
          // prepare feeds. use model input names as keys.
          var feeds = { idx: idx, xx_att: xx_att, aa_att: aa_att, bb_att: bb_att, pp_att: pp_att, xx_ffn: xx_ffn }
          var fetches = { xx_att_r: null, aa_att_r: null, bb_att_r: null, pp_att_r: null, xx_ffn_r: null }
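          // The logits output `x` is deliberately left out of the fetches while
          // the prompt is being fed; it is added once the prompt is exhausted
          // (see the mode switch below), so the final projection matmul is
          // skipped during prompting.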
// "\nIn a shocking finding"
var prompt = [187, 688, 247, 29103, 4560]
var gen = 64
var len = gen + prompt.length
var ctx = [ prompt.shift() ]
var start = performance.now()
          // feed inputs and run
          for (var i = 0; i < len; i++) {
            // RWKV only looks at the very last token in the context.
            // It could just as well be a single int32 input, but ONNX doesn't
            // allow passing a bare int32; it has to be a tensor. So instead of
            // a single-element tensor, I left it ctx_len-sized and only write
            // the last slot.
            idx.data[1023] = ctx.at(-1)
            var results = await session.run(feeds, fetches)
            if (prompt.length == 0) {
              var token = greedy_sampling(results.x.data)
              ctx.push( token )
            } else {
              ctx.push( prompt.shift() )
              // Mode switch: the prompt is exhausted, so request the logits
              // output `x` from now on (this enables the final matmul).
              if (prompt.length == 0) {
                fetches.x = null
                console.log("Prompting:", (performance.now() - start) / (i + 1), "ms/token")
                start = performance.now()
              }
            }
            // Feed the updated recurrent state back in for the next step.
            feeds.xx_att = results.xx_att_r
            feeds.aa_att = results.aa_att_r
            feeds.bb_att = results.bb_att_r
            feeds.pp_att = results.pp_att_r
            feeds.xx_ffn = results.xx_ffn_r
            console.log(1 + i, "/", len)
          }
console.log("Generation:", (performance.now() - start) / (gen + 1), "ms/token")
console.log(ctx + "")
} catch (e) {
console.log(`failed to inference ONNX model: ${e}.`)
}
}
main()
</script>
</body>
</html>