-
Notifications
You must be signed in to change notification settings - Fork 180
/
Copy pathtest_page.cu
208 lines (192 loc) · 8.78 KB
/
test_page.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <flashinfer/page.cuh>
#include <type_traits>
#include "cpu_reference.h"
#include "utils.h"
using namespace flashinfer;
/*!
 * \brief Check AppendPagedKVCache / AppendPagedKVCacheDecode against the CPU reference.
 *
 * Runs 2 * num_conv_rounds append rounds for every request in the batch. Even rounds append a
 * per-request chunk whose length is drawn by utils::vec_randint_(append_len, 1,
 * max_prefill_len + 1) and go through AppendPagedKVCache. Odd rounds append max_decode_len
 * tokens per request and go through AppendPagedKVCacheDecode. Each round also appends the same
 * keys/values on the host with cpu_reference::append_paged_kv_cache. At the end, the whole GPU
 * K and V page pools are compared element-wise with the host pools.
 *
 * \tparam T element type of the K/V cache
 * \param page_size number of tokens per page
 * \param batch_size number of requests
 * \param num_heads number of heads stored per token
 * \param head_dim size of each head
 * \param kv_layout page layout forwarded to paged_kv_t (the caller passes kNHD or kHND)
 */
template <typename T>
void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, size_t num_heads,
                                         size_t head_dim, QKVLayout kv_layout) {
  // number of conversation rounds; each one is a prefill round followed by a decode round
  size_t num_conv_rounds = 3;
  size_t max_decode_len = 1;
  size_t max_prefill_len = 128;
  // Size of the page pool shared by all requests; page_counter below hands out page ids from it.
  // NOTE(review): intended as an upper bound on pages allocated across all rounds, with one page
  // of slack per round for partially filled pages.
  size_t max_num_pages =
      num_conv_rounds * batch_size * ((max_decode_len + max_prefill_len) / page_size + 1);
  std::vector<T> k_data_cpu(max_num_pages * page_size * num_heads * head_dim);
  std::vector<T> v_data_cpu(max_num_pages * page_size * num_heads * head_dim);
  utils::vec_zero_(k_data_cpu);
  utils::vec_zero_(v_data_cpu);
  thrust::device_vector<T> k_data_gpu(k_data_cpu), v_data_gpu(v_data_cpu);
  std::vector<int32_t> seq_len(batch_size);
  utils::vec_fill_(seq_len, 0);
  std::vector<std::vector<int32_t>> page_indices(batch_size);
  std::vector<int32_t> last_page_len(batch_size);
  utils::vec_fill_(last_page_len, 0);
  size_t page_counter = 0;
  for (size_t round = 0; round < 2 * num_conv_rounds; ++round) {
    std::vector<int32_t> append_len(batch_size);
    std::vector<int32_t> append_indptr{0};
    std::vector<int32_t> batch_indices;
    std::vector<int32_t> positions;
    std::vector<std::vector<T>> keys;
    std::vector<std::vector<T>> values;
    if (round % 2 == 0) {
      utils::vec_randint_(append_len, 1, max_prefill_len + 1);
    } else {
      utils::vec_fill_<int32_t>(append_len, max_decode_len);
    }
    for (size_t i = 0; i < batch_size; ++i) {
      append_indptr.push_back(append_indptr.back() + append_len[i]);
      seq_len[i] += append_len[i];
      // int32_t index to match the int32_t append_len (avoids a signed/unsigned comparison).
      for (int32_t j = 0; j < append_len[i]; ++j) {
        // When the last page is full (or no page exists yet), allocate a fresh page id.
        if (last_page_len[i] % page_size == 0) {
          page_indices[i].push_back(page_counter++);
          last_page_len[i] = 1;
        } else {
          last_page_len[i] += 1;
        }
        // Per-token (request, position) pairs consumed by the prefill append kernel.
        batch_indices.push_back(i);
        positions.push_back(seq_len[i] - append_len[i] + j);
      }
      std::vector<T> ki(append_len[i] * num_heads * head_dim),
          vi(append_len[i] * num_heads * head_dim);
      utils::vec_normal_(ki);
      utils::vec_normal_(vi);
      keys.push_back(ki);
      values.push_back(vi);
    }
    // Flatten per-request page lists into CSR form (indptr / indices).
    std::vector<int32_t> indptr_cpu{0};
    std::vector<int32_t> indices_cpu;
    for (size_t i = 0; i < batch_size; ++i) {
      for (size_t j = 0; j < page_indices[i].size(); ++j) {
        indices_cpu.push_back(page_indices[i][j]);
      }
      indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size());
    }
    paged_kv_t<T, int32_t> paged_kv_cpu(num_heads, page_size, head_dim, batch_size, kv_layout,
                                        /*k_data=*/k_data_cpu.data(),
                                        /*v_data=*/v_data_cpu.data(), indices_cpu.data(),
                                        indptr_cpu.data(), last_page_len.data());
    cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr);
    thrust::device_vector<int32_t> indptr_gpu(indptr_cpu);
    thrust::device_vector<int32_t> indices_gpu(indices_cpu);
    thrust::device_vector<int32_t> last_page_len_gpu(last_page_len);
    paged_kv_t<T, int32_t> paged_kv_gpu(num_heads, page_size, head_dim, batch_size, kv_layout,
                                        /*k_data=*/thrust::raw_pointer_cast(k_data_gpu.data()),
                                        /*v_data=*/thrust::raw_pointer_cast(v_data_gpu.data()),
                                        thrust::raw_pointer_cast(indices_gpu.data()),
                                        thrust::raw_pointer_cast(indptr_gpu.data()),
                                        thrust::raw_pointer_cast(last_page_len_gpu.data()));
    thrust::device_vector<int32_t> batch_indices_gpu(batch_indices);
    thrust::device_vector<int32_t> positions_gpu(positions);
    // Pack all appended tokens contiguously: token t occupies [t * num_heads * head_dim, ...).
    thrust::device_vector<T> keys_gpu(append_indptr.back() * num_heads * head_dim);
    thrust::device_vector<T> values_gpu(append_indptr.back() * num_heads * head_dim);
    for (size_t i = 0; i < batch_size; ++i) {
      thrust::device_vector<T> ki(keys[i]);
      thrust::device_vector<T> vi(values[i]);
      thrust::copy(ki.begin(), ki.end(),
                   keys_gpu.begin() + append_indptr[i] * num_heads * head_dim);
      thrust::copy(vi.begin(), vi.end(),
                   values_gpu.begin() + append_indptr[i] * num_heads * head_dim);
    }
    if (round % 2 == 0) {
      // call prefill kernel
      cudaError_t status =
          AppendPagedKVCache(paged_kv_gpu, thrust::raw_pointer_cast(keys_gpu.data()),
                             thrust::raw_pointer_cast(values_gpu.data()),
                             thrust::raw_pointer_cast(batch_indices_gpu.data()),
                             thrust::raw_pointer_cast(positions_gpu.data()),
                             /*nnz=*/append_indptr.back(),
                             /*append_k_stride_n=*/num_heads * head_dim,
                             /*append_k_stride_h=*/head_dim,
                             /*append_v_stride_n=*/num_heads * head_dim,
                             /*append_v_stride_h=*/head_dim);
      EXPECT_EQ(status, cudaSuccess) << "AppendPagedKVCache kernel launch failed, error message: "
                                     << cudaGetErrorString(status);
    } else {
      // call decode kernel
      cudaError_t status =
          AppendPagedKVCacheDecode(paged_kv_gpu, thrust::raw_pointer_cast(keys_gpu.data()),
                                   thrust::raw_pointer_cast(values_gpu.data()));
      EXPECT_EQ(status, cudaSuccess)
          << "AppendPagedKVCacheDecode kernel launch failed, error message: "
          << cudaGetErrorString(status);
    }
    // The returned status only covers the launch. Synchronize so that faults raised while the
    // kernel runs are reported against this round, not later at the final device-to-host copy.
    cudaError_t sync_status = cudaDeviceSynchronize();
    EXPECT_EQ(sync_status, cudaSuccess)
        << "Append kernel execution failed in round " << round
        << ", error message: " << cudaGetErrorString(sync_status);
  }
  thrust::host_vector<T> k_data_gpu_h(k_data_gpu), v_data_gpu_h(v_data_gpu);
  size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
  bool nan_detected = false;
  for (size_t i = 0; i < k_data_cpu.size(); ++i) {
    if (std::isnan(float(k_data_gpu_h[i]))) {
      nan_detected = true;
    }
    num_result_errors_atol_1e_3_rtol_1e_3 +=
        (!utils::isclose(float(k_data_cpu[i]), float(k_data_gpu_h[i]), 1e-3, 1e-3));
  }
  for (size_t i = 0; i < v_data_cpu.size(); ++i) {
    if (std::isnan(float(v_data_gpu_h[i]))) {
      nan_detected = true;
    }
    num_result_errors_atol_1e_3_rtol_1e_3 +=
        (!utils::isclose(float(v_data_cpu[i]), float(v_data_gpu_h[i]), 1e-3, 1e-3));
  }
  float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
                                   float(k_data_cpu.size() + v_data_cpu.size());
  std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size
            << ", batch_size=" << batch_size << ", num_heads=" << num_heads
            << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
  EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
  EXPECT_EQ(nan_detected, false) << "Nan detected in the result.";
}
/*!
 * \brief Run _TestAppendPagedKVKernelCorrectness<T> over the full configuration grid.
 *
 * Nesting order, from outermost to innermost:
 * page size, batch size, head count, KV layout, head dimension.
 */
template <typename T>
void TestAppendPagedKVKernelCorrectness() {
  const size_t kPageSizes[] = {1, 3, 7, 17};
  const size_t kBatchSizes[] = {1, 2, 3, 5, 7, 23, 79, 91};
  const size_t kNumHeadsList[] = {32};
  const QKVLayout kLayouts[] = {QKVLayout::kNHD, QKVLayout::kHND};
  const size_t kHeadDims[] = {64, 128, 256};
  for (const size_t ps : kPageSizes) {
    for (const size_t bs : kBatchSizes) {
      for (const size_t nh : kNumHeadsList) {
        for (const QKVLayout layout : kLayouts) {
          for (const size_t hd : kHeadDims) {
            _TestAppendPagedKVKernelCorrectness<T>(ps, bs, nh, hd, layout);
          }
        }
      }
    }
  }
}
// Registers one gtest case named AppendPagedKVKernelCorrectnessTest<suffix> that sweeps the
// append-paged-KV correctness grid for element type `dtype`.
#define FLASHINFER_APPEND_PAGED_KV_TEST(suffix, dtype)                          \
  TEST(FlashInferCorrectnessTest, AppendPagedKVKernelCorrectnessTest##suffix) { \
    TestAppendPagedKVKernelCorrectness<dtype>();                                \
  }

FLASHINFER_APPEND_PAGED_KV_TEST(FP16, half)
FLASHINFER_APPEND_PAGED_KV_TEST(FP32, float)
// Optional element types are only registered when the matching build flag is defined.
#ifdef FLASHINFER_ENABLE_BF16
FLASHINFER_APPEND_PAGED_KV_TEST(BF16, __nv_bfloat16)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E4M3
FLASHINFER_APPEND_PAGED_KV_TEST(E4M3, __nv_fp8_e4m3)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E5M2
FLASHINFER_APPEND_PAGED_KV_TEST(E5M2, __nv_fp8_e5m2)
#endif

#undef FLASHINFER_APPEND_PAGED_KV_TEST