-
Notifications
You must be signed in to change notification settings - Fork 11
/
export_ncnn.cpp
43 lines (41 loc) · 1.45 KB
/
export_ncnn.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include <fstream>
#include <iostream>
#include <string>
#include <kernels/export-ncnn/kernels.h>
// Command-line entry point of the export_ncnn tool.
//
// Validates the command line, selects a weight dtype and forwards everything
// to rwkv::ncnnmeta::ExportModel.
//
// Arguments:
//   argv[1]  input path; must be openable for reading.
//   argv[2]  output prefix, passed through to ExportModel unchanged.
//   argv[3]  optional weight dtype: "int4"/"i4", "int8"/"i8" or "fp16".
//            When omitted, fp16 is used.
//
// Returns 0 once ExportModel has returned, and 1 on a wrong argument count,
// an input file that cannot be opened, or an unsupported weight dtype.
int main(int argc, char **argv) {
  // Single source of truth for the usage line (printed on two error paths).
  const auto print_usage = [] {
    std::cerr
        << "Usage: ./export_ncnn <input path> <output prefix> [<weight_dtype>]"
        << std::endl;
  };

  if (argc != 3 && argc != 4) {
    print_usage();
    return 1;
  }

  // Probe the input up front so a bad path is reported before exporting.
  if (std::ifstream ifs(argv[1]); !ifs.good()) {
    std::cerr << "Failed to open " << argv[1] << std::endl;
    print_usage();
    return 1;
  }

  // Start from the documented default so the variable is never read while
  // holding an indeterminate value.
  rwkv::DType weight_dtype = rwkv::DType::kFloat16;
  if (argc == 3) {
    std::cout
        << "Using fp16 weight dtype... You can also specify weight dtype by "
           "adding a third argument. For example, ./export_ncnn <input path> "
           "<output prefix> int8, which generates a faster and smaller model."
        << std::endl;
  } else {
    const std::string weight_dtype_str(argv[3]);
    if (weight_dtype_str == "int4" || weight_dtype_str == "i4") {
      weight_dtype = rwkv::DType::kInt4;
    } else if (weight_dtype_str == "int8" || weight_dtype_str == "i8") {
      weight_dtype = rwkv::DType::kInt8;
    } else if (weight_dtype_str == "fp16") {
      weight_dtype = rwkv::DType::kFloat16;
    } else {
      RV_UNIMPLEMENTED() << "Only int4, int8, and fp16 are supported. But got "
                         << weight_dtype_str << ".";
      // NOTE(review): RV_UNIMPLEMENTED is defined outside this file and
      // presumably terminates the process - confirm. The explicit return
      // guarantees we never export with an unrequested dtype if it does not.
      return 1;
    }
  }

  rwkv::ncnnmeta::ExportModel(argv[1], weight_dtype, argv[2]);
  return 0;
}