Commit

export onnx
Signed-off-by: daquexian <daquexian566@gmail.com>
daquexian committed Oct 10, 2023
1 parent 2680cd9 commit 85544f8
Showing 9 changed files with 181 additions and 44 deletions.
11 changes: 10 additions & 1 deletion CMakeLists.txt
@@ -79,6 +79,15 @@ if (FR_ENABLE_ONNX)
endif()

if (FR_ENABLE_ONNX)
option(protobuf_BUILD_TESTS "Build tests" OFF)
FetchContent_Declare(
protobuf
GIT_REPOSITORY https://github.com/protocolbuffers/protobuf
GIT_TAG v3.20.1
SYSTEM
SOURCE_SUBDIR cmake
)
FetchContent_MakeAvailable(protobuf)
FetchContent_Declare(
onnx
GIT_REPOSITORY https://github.com/onnx/onnx
@@ -97,7 +106,7 @@ if (FR_ENABLE_ONNX)
# )
# FetchContent_MakeAvailable(onnxruntime)
set(onnx_kernel_srcs
kernels/export-ncnn/kernels.cpp
kernels/export-onnx/kernels.cpp
)
endif()

4 changes: 2 additions & 2 deletions export_onnx.cpp
@@ -6,14 +6,14 @@
int main(int argc, char **argv) {
if (argc != 3) {
std::cerr
<< "Usage: ./export_ncnn <input path> <output prefix>"
<< "Usage: " << argv[0] << " <input path> <output prefix>"
<< std::endl;
return 1;
}
if (std::ifstream ifs(argv[1]); !ifs.good()) {
std::cerr << "Failed to open " << argv[1] << std::endl;
std::cerr
<< "Usage: ./export_ncnn <input path> <output prefix>"
<< "Usage: " << argv[0] << " <input path> <output prefix>"
<< std::endl;
return 1;
}
11 changes: 7 additions & 4 deletions kernels/export-ncnn/kernels.cpp
@@ -105,15 +105,18 @@ void ExportModel(const std::string &input_path, DType weight_dtype,
const std::string &output_prefix) {
RV_CHECK(weight_dtype == DType::kFloat16 || weight_dtype == DType::kInt8 ||
weight_dtype == DType::kInt4);
default_dispatch_device() = Device::kNCNNMeta;

rwkv::ncnnmeta::init(weight_dtype, output_prefix + ".bin",
output_prefix + ".param", output_prefix + ".config");
init(weight_dtype, output_prefix + ".bin",
output_prefix + ".param", output_prefix + ".config");

// NOTE: fp32 here is just a placeholder. The dtype used by ncnn is determined
// by the weight_dtype parameter.
rwkv::Model model(input_path, "export-ncnn fp32");
Model model(input_path, "export-ncnn fp32");
model.Run(0);
rwkv::ncnnmeta::destroy(model);
destroy(model);

default_dispatch_device() = std::nullopt;
}

void append_data_to_bin_file(const Tensor &tensor, bool write_tag) {
102 changes: 89 additions & 13 deletions kernels/export-onnx/kernels.cpp
@@ -55,16 +55,27 @@ ModelProto destory() {

void ExportModel(const std::string &input_path,
const std::string &output_path) {

// NOTE: fp32 here is just a placeholder. The dtype used by ncnn is determined
// by the weight_dtype parameter.
default_dispatch_device() = Device::kONNXMeta;
Model model(input_path, "export-onnx fp32");
model.Run(0);
default_dispatch_device() = std::nullopt;
ModelProto model_proto = destory();
// save model_proto to output_path
std::ofstream ofs(output_path, std::ios::binary);
RV_CHECK(ofs.good());
model_proto.SerializeToOstream(&ofs);
{
std::ofstream config_file(output_path + ".config");
config_file << "version: " << model.version() << std::endl;
config_file << "head_size: " << model.head_size() << std::endl;
config_file << "n_layer: " << model.n_layer() << std::endl;
config_file << "n_embd: " << model.n_embd() << std::endl;
config_file << "n_att: " << model.n_att() << std::endl;
config_file << "n_ffn: " << model.n_ffn() << std::endl;
std::string kOnnxImplVersion = "1";
config_file << "onnx_impl_version: " << kOnnxImplVersion << std::endl;
config_file.close();
}
}

DType _dtype = DType::kFloat32;
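The block above also writes a small "key: value" text config next to the ONNX file (at output_path + ".config"). A minimal reader for that format, as an illustration only -- this helper is not part of the commit, and the file name in the usage comment is hypothetical:

#include <fstream>
#include <map>
#include <string>

// Parse the "key: value" lines written by ExportModel into a string map.
std::map<std::string, std::string> ReadExportConfig(const std::string &path) {
  std::map<std::string, std::string> config;
  std::ifstream ifs(path);
  std::string line;
  while (std::getline(ifs, line)) {
    auto pos = line.find(": ");
    if (pos == std::string::npos) continue;             // skip malformed lines
    config[line.substr(0, pos)] = line.substr(pos + 2);  // key -> value
  }
  return config;
}

// Usage (path hypothetical):
//   auto cfg = ReadExportConfig("rwkv.onnx.config");
//   int n_layer = std::stoi(cfg.at("n_layer"));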
@@ -74,10 +85,12 @@ int fr_dtype_to_onnx_dtype(DType fr_dtype) {
return TensorProto::FLOAT;
} else if (fr_dtype == DType::kFloat16) {
return TensorProto::FLOAT16;
} else if (fr_dtype == DType::kInt32) {
return TensorProto::INT32;
} else if (fr_dtype == DType::kInt64) {
return TensorProto::INT64;
} else {
RV_UNIMPLEMENTED();
RV_UNIMPLEMENTED() << "Unsupported dtype: " << dtype_to_string(fr_dtype);
}
}

Expand Down Expand Up @@ -118,6 +131,7 @@ Tensor possible_initializer(const Tensor &x) {
NodeProto *node = new_node();
node->set_name(std::to_string(unique_op_id()));
node->set_op_type("Constant");
std::cout << "add constant " << output.name << std::endl;
node->add_output(output.name);
node->add_attribute()->set_name("value");
node->mutable_attribute(0)->set_type(onnx::AttributeProto::TENSOR);
@@ -130,6 +144,18 @@ Tensor possible_initializer(const Tensor &x) {
return output;
}

Tensor constant(const Tensor &x) {
RV_CHECK(x.device() == Device::kCPU);
return possible_initializer(x);
}

Tensor constant(const std::vector<int> &x) {
Tensor x_t = Tensor::FromPtr(const_cast<int *>(x.data()),
{static_cast<long>(x.size())}, DType::kInt32,
Device::kCPU);
return constant(x_t);
}

Tensor gather(const Tensor &x, const Tensor &index) {
RV_CHECK(x.shape().size() == 2);
RV_CHECK(index.shape().size() == 0);
@@ -152,15 +178,66 @@ Tensor layernorm(const Tensor &x, const Tensor &_weight, const Tensor &_bias) {
node->add_input(weight.name);
node->add_input(bias.name);
node->add_output(output.name);
return output;
}

Tensor concat(const std::vector<Tensor> &xs, int axis) {
RV_CHECK(xs.size() > 0);
RV_CHECK(axis == 1);
std::vector<Shape> x_shapes;
for (auto &x : xs) {
x_shapes.push_back(x.shape());
}
auto output = Tensor::Empty(shape::concat(x_shapes, axis), xs[0].dtype(),
xs[0].device());
NodeProto *node = new_node();
node->set_op_type("Concat");
for (auto &x : xs) {
node->add_input(x.name);
}
node->add_output(output.name);
node->add_attribute()->set_name("axis");
node->mutable_attribute(0)->set_i(1);
node->mutable_attribute(0)->set_i(axis);
node->mutable_attribute(0)->set_type(onnx::AttributeProto::INT);
node->add_attribute()->set_name("epsilon");
node->mutable_attribute(1)->set_f(1e-5f);
node->mutable_attribute(1)->set_type(onnx::AttributeProto::FLOAT);
return output;
}

Tensor slice(const Tensor &x, const std::vector<int> &starts,
const std::vector<int> &ends, const std::vector<int> &axes) {
RV_CHECK(axes.size() == starts.size());
RV_CHECK(axes.size() == ends.size());
auto starts_t = constant(starts);
auto ends_t = constant(ends);
auto axes_t = constant(axes);
auto output = Tensor::Empty(shape::slice(x.shape(), starts, ends, axes),
x.dtype(), x.device());
NodeProto *node = new_node();
node->set_op_type("Slice");
node->add_input(x.name);
node->add_input(starts_t.name);
node->add_input(ends_t.name);
node->add_input(axes_t.name);
node->add_output(output.name);
return output;
}

Tensor groupnorm(const Tensor &x, int num_groups, const Tensor &_weight,
const Tensor &_bias) {
auto weight = possible_initializer(_weight);
auto bias = possible_initializer(_bias);
int len = x.shape()[1];
RV_CHECK(len % num_groups == 0);
int group_size = len / num_groups;
std::vector<Tensor> ln_outs;
for (int i = 0; i < num_groups; i++) {
auto x_slice = slice(x, {i * group_size}, {(i + 1) * group_size}, {1});
auto w_slice = slice(weight, {i * group_size}, {(i + 1) * group_size}, {0});
auto b_slice = slice(bias, {i * group_size}, {(i + 1) * group_size}, {0});
ln_outs.push_back(layernorm(x_slice, w_slice, b_slice));
}
return concat(ln_outs, 1);
}
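// Illustration, not part of the commit: the groupnorm above emulates
// GroupNormalization with ops available in older ONNX opsets. The channel
// axis (axis 1) is cut into num_groups slices, each slice is normalized by a
// LayerNormalization node with the matching weight/bias slice, and the pieces
// are concatenated back, i.e. per group g
//
//   y_g = (x_g - mean(x_g)) / sqrt(var(x_g) + eps) * w_g + b_g
//   y   = concat(y_0, ..., y_{num_groups-1})        (along axis 1)
//
// which matches GroupNorm provided each emitted LayerNormalization
// normalizes over the sliced axis (its usual default for a [1, len] input).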

Tensor matmul(const Tensor &_x, const Tensor &_y) {
auto x = possible_initializer(_x);
auto y = possible_initializer(_y);
@@ -236,14 +313,14 @@ Tensor sigmoid(const Tensor &x) {
}

Tensor reshape(const Tensor &x, const Shape &shape) {
Tensor shape_cpu_tensor = Tensor::FromPtr(const_cast<int64_t *>(shape.data()),
{static_cast<long>(shape.size())},
DType::kInt64, Device::kCPU);
auto shape_tensor = constant(shape_cpu_tensor);
Tensor output = Tensor::Empty(shape, x.dtype(), x.device());
NodeProto *node = new_node();
node->set_op_type("Reshape");
node->add_input(x.name);
Tensor shape_cpu_tensor = Tensor::FromPtr(const_cast<int64_t *>(shape.data()),
{static_cast<long>(shape.size())},
DType::kInt64, Device::kCPU);
auto shape_tensor = possible_initializer(shape_cpu_tensor);
node->add_input(shape_tensor.name);
node->add_output(output.name);
return output;
@@ -254,7 +331,6 @@ Tensor mark_as_output(const Tensor &x, const std::string &name) {
output->set_name(name);
output->mutable_type()->mutable_tensor_type()->set_elem_type(
fr_dtype_to_onnx_dtype(_dtype));
std::cout << "output name: " << name << ", dim: " << x.shape() << std::endl;
for (auto dim : x.shape()) {
output->mutable_type()
->mutable_tensor_type()
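For context on consuming the result: ExportModel serializes an ordinary ONNX ModelProto to the output path, and every node type visible in this diff (Constant, Gather, LayerNormalization, Concat, Slice, MatMul, Sigmoid, Reshape) is a standard ONNX operator, so the file should load in a stock ONNX runtime. A minimal onnxruntime C++ sketch, purely illustrative -- onnxruntime is not wired into this build (its FetchContent block in CMakeLists.txt is still commented out) and the file name is hypothetical:

// Illustrative only; not part of this commit.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "faster-rwkv-export");
  Ort::SessionOptions opts;
  // Narrow-char path overload: fine on Linux/macOS builds of onnxruntime.
  Ort::Session session(env, "rwkv.onnx", opts);
  return 0;
}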
41 changes: 19 additions & 22 deletions kernels/kernels.h
@@ -35,15 +35,15 @@ att_one_v5(const Tensor &x, const Tensor &sx, const Tensor &s,

inline std::tuple<Tensor, Tensor, Tensor>
att_one_v5_1(const Tensor &x, const Tensor &sx, const Tensor &s,
const Tensor &ln_w, const Tensor &ln_b, const Tensor &lx_w,
const Tensor &lx_b, const Tensor &k_mix, const Tensor &v_mix,
const Tensor &r_mix, const Tensor &g_mix, const Tensor &t_decay, const Tensor &t_first,
const Tensor &kw, const Tensor &vw, const Tensor &rw, const Tensor& gw,
const Tensor &ow) {
const Tensor &ln_w, const Tensor &ln_b, const Tensor &lx_w,
const Tensor &lx_b, const Tensor &k_mix, const Tensor &v_mix,
const Tensor &r_mix, const Tensor &g_mix, const Tensor &t_decay,
const Tensor &t_first, const Tensor &kw, const Tensor &vw,
const Tensor &rw, const Tensor &gw, const Tensor &ow) {
auto tmp = KernelRegistry::Instance().Get<decltype(att_one_v5_1) *>(
"att_one_v5_1", x.device());
return tmp(x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay,
t_first, kw, vw, rw, gw, ow);
return tmp(x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix,
t_decay, t_first, kw, vw, rw, gw, ow);
}

// def cuda_ffn_one_fp16(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw,
@@ -95,12 +95,12 @@ inline Tensor matmul(const Tensor &a, const Tensor &b) {
inline Tensor add(const Tensor &x, const Tensor &y) {
// TODO: global device
return KernelRegistry::Instance().Get<decltype(add) *>(
"add", Device::kNCNNMeta)(x, y);
"add", default_dispatch_device().value_or(x.device()))(x, y);
}

inline Tensor sub(float x, const Tensor &y) {
return KernelRegistry::Instance().Get<Tensor (*)(float, const Tensor &)>(
"rsub_scalar", Device::kNCNNMeta)(x, y);
"rsub_scalar", default_dispatch_device().value_or(y.device()))(x, y);
}

inline Tensor sub(const Tensor &x, const Tensor &y) {
@@ -109,8 +109,8 @@ inline Tensor sub(const Tensor &x, const Tensor &y) {
}

inline Tensor mul(const Tensor &x, const Tensor &y) {
return KernelRegistry::Instance().Get<decltype(mul) *>("mul", x.device())(x,
y);
return KernelRegistry::Instance().Get<decltype(mul) *>(
"mul", default_dispatch_device().value_or(x.device()))(x, y);
}

inline Tensor div(const Tensor &x, const Tensor &y) {
@@ -138,18 +138,16 @@ inline Tensor maximum(const Tensor &x, const Tensor &y) {
}

inline Tensor softmax(const Tensor &x, float temperature) {
return KernelRegistry::Instance().Get<decltype(softmax) *>("softmax",
x.device())(x, temperature);
return KernelRegistry::Instance().Get<decltype(softmax) *>(
"softmax", x.device())(x, temperature);
}

inline Tensor reshape(const Tensor &x, const Shape &shape) {
return KernelRegistry::Instance().Get<decltype(reshape) *>("reshape",
x.device())(x, shape);
return KernelRegistry::Instance().Get<decltype(reshape) *>(
"reshape", x.device())(x, shape);
}

inline Tensor flatten(const Tensor &x) {
return reshape(x, {x.numel()});
}
inline Tensor flatten(const Tensor &x) { return reshape(x, {x.numel()}); }

inline Tensor unsqueeze(const Tensor &x, int dim) {
auto new_shape = x.shape();
@@ -165,10 +163,9 @@ inline Tensor mark_as_output(const Tensor &x, const std::string &name) {
class Model;

inline void init_model(Model *model, Device device, const std::string &path,
const std::string &strategy, const std::any& extra) {
KernelRegistry::Instance()
.Get<decltype(init_model)*>(
"init_model", device)(model, device, path, strategy, extra);
const std::string &strategy, const std::any &extra) {
KernelRegistry::Instance().Get<decltype(init_model) *>("init_model", device)(
model, device, path, strategy, extra);
}

inline Tensor ModelForward(const Model *model, Device device, int id,
35 changes: 34 additions & 1 deletion kernels/shape/shape_inference.cpp
@@ -46,7 +46,8 @@ Shape matmul(const Shape &x, const Shape &y) {
Shape broadcast_binary(const Shape &x, const Shape &y) {
auto nrank = std::max(x.size(), y.size());
Shape output_shape(nrank);
for (int i = nrank - 1, x_idx = x.size() - 1, y_idx = y.size() - 1; i >= 0; i--, x_idx--, y_idx--) {
for (int i = nrank - 1, x_idx = x.size() - 1, y_idx = y.size() - 1; i >= 0;
i--, x_idx--, y_idx--) {
if (x_idx < 0) {
output_shape[i] = y[y_idx];
} else if (y_idx < 0) {
@@ -63,5 +64,37 @@ Shape broadcast_binary(const Shape &x, const Shape &y) {
}
return output_shape;
}

Shape concat(const std::vector<Shape> &xs, int axis) {
RV_CHECK(xs.size() > 0);
RV_CHECK(axis >= 0 && axis < xs[0].size());
Shape output_shape = xs[0];
for (int i = 1; i < xs.size(); i++) {
RV_CHECK(xs[i].size() == output_shape.size());
for (int j = 0; j < xs[i].size(); j++) {
if (j == axis) {
output_shape[j] += xs[i][j];
} else {
RV_CHECK(xs[i][j] == output_shape[j]);
}
}
}
return output_shape;
}

Shape slice(const Shape &x, const std::vector<int> &starts,
const std::vector<int> &ends, const std::vector<int> &axes) {
RV_CHECK(starts.size() == ends.size());
RV_CHECK(starts.size() == axes.size());
Shape output_shape = x;
for (int i = 0; i < starts.size(); i++) {
RV_CHECK(starts[i] >= 0 && starts[i] < x[axes[i]]);
RV_CHECK(ends[i] >= 0 && ends[i] <= x[axes[i]]);
RV_CHECK(starts[i] < ends[i]);
output_shape[axes[i]] = ends[i] - starts[i];
}
return output_shape;
}

} // namespace shape
} // namespace rwkv
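A quick sanity check of the two new shape helpers (a standalone sketch, not part of the commit, assuming the repo's header is on the include path and that Shape brace-initializes and indexes like a vector of integers):

#include <cassert>
#include "kernels/shape/shape_inference.h"

int main() {
  using namespace rwkv;
  Shape x = {1, 2048};
  // Slice elements [0, 64) along axis 1: (1, 2048) -> (1, 64)
  Shape s = shape::slice(x, /*starts=*/{0}, /*ends=*/{64}, /*axes=*/{1});
  assert(s[0] == 1 && s[1] == 64);
  // Concatenate two (1, 64) shapes along axis 1: -> (1, 128)
  Shape c = shape::concat({s, s}, /*axis=*/1);
  assert(c[0] == 1 && c[1] == 128);
  return 0;
}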
5 changes: 5 additions & 0 deletions kernels/shape/shape_inference.h
@@ -6,5 +6,10 @@ Shape matmul(const Shape &x, const Shape &y);

Shape broadcast_binary(const Shape &x, const Shape &y);

Shape concat(const std::vector<Shape> &xs, int axis);

Shape slice(const Shape &x, const std::vector<int> &starts,
const std::vector<int> &ends, const std::vector<int> &axes);

} // namespace shape
} // namespace rwkv
5 changes: 5 additions & 0 deletions tensor.cpp
@@ -12,6 +12,11 @@

namespace rwkv {

std::optional<Device>& default_dispatch_device() {
static std::optional<Device> _default_dispatch_device = std::nullopt;
return _default_dispatch_device;
}

// operator<< for Shape
std::ostream &operator<<(std::ostream &os, const Shape &shape) {
os << "(";
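The function-local static above is the single process-wide override that ExportModel toggles around tracing (set to the meta device, run the model, reset to std::nullopt). A small RAII guard is one way a caller could make that pattern exception-safe; this is an illustrative sketch, not part of the commit:

#include <optional>
// Assumes the rwkv header that declares Device and default_dispatch_device().

struct DispatchDeviceGuard {
  std::optional<rwkv::Device> saved;
  explicit DispatchDeviceGuard(rwkv::Device d)
      : saved(rwkv::default_dispatch_device()) {
    rwkv::default_dispatch_device() = d;  // route kernel dispatch to the meta device
  }
  ~DispatchDeviceGuard() { rwkv::default_dispatch_device() = saved; }  // restore on scope exit
};

// Usage:
//   DispatchDeviceGuard guard(rwkv::Device::kONNXMeta);
//   rwkv::Model model(input_path, "export-onnx fp32");
//   model.Run(0);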