diff --git a/CMakeLists.txt b/CMakeLists.txt
index 84679f5..db5f1ab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -79,6 +79,15 @@ if (FR_ENABLE_NCNN)
 endif()
 
 if (FR_ENABLE_ONNX)
+  option(protobuf_BUILD_TESTS "Build tests" OFF)
+  FetchContent_Declare(
+    protobuf
+    GIT_REPOSITORY https://github.com/protocolbuffers/protobuf
+    GIT_TAG v3.20.1
+    SYSTEM
+    SOURCE_SUBDIR cmake
+  )
+  FetchContent_MakeAvailable(protobuf)
   FetchContent_Declare(
     onnx
     GIT_REPOSITORY https://github.com/onnx/onnx
@@ -97,7 +106,7 @@ if (FR_ENABLE_ONNX)
   #   )
   #   FetchContent_MakeAvailable(onnxruntime)
   set(onnx_kernel_srcs
-    kernels/export-ncnn/kernels.cpp
+    kernels/export-onnx/kernels.cpp
   )
 endif()
 
diff --git a/export_onnx.cpp b/export_onnx.cpp
index 88e71aa..6525217 100644
--- a/export_onnx.cpp
+++ b/export_onnx.cpp
@@ -6,14 +6,14 @@
 int main(int argc, char **argv) {
   if (argc != 3) {
     std::cerr
-        << "Usage: ./export_ncnn <input path> <output path>"
+        << "Usage: " << argv[0] << " <input path> <output path>"
         << std::endl;
     return 1;
   }
   if (std::ifstream ifs(argv[1]); !ifs.good()) {
     std::cerr << "Failed to open " << argv[1] << std::endl;
     std::cerr
-        << "Usage: ./export_ncnn <input path> <output path>"
+        << "Usage: " << argv[0] << " <input path> <output path>"
         << std::endl;
     return 1;
   }
diff --git a/kernels/export-ncnn/kernels.cpp b/kernels/export-ncnn/kernels.cpp
index 5ed3493..9f79f4e 100644
--- a/kernels/export-ncnn/kernels.cpp
+++ b/kernels/export-ncnn/kernels.cpp
@@ -105,15 +105,18 @@ void ExportModel(const std::string &input_path, DType weight_dtype,
                  const std::string &output_prefix) {
   RV_CHECK(weight_dtype == DType::kFloat16 || weight_dtype == DType::kInt8 ||
            weight_dtype == DType::kInt4);
+  default_dispatch_device() = Device::kNCNNMeta;
 
-  rwkv::ncnnmeta::init(weight_dtype, output_prefix + ".bin",
-                       output_prefix + ".param", output_prefix + ".config");
+  init(weight_dtype, output_prefix + ".bin",
+       output_prefix + ".param", output_prefix + ".config");
   // NOTE: fp32 here is just a placeholder. The dtype used by ncnn is determined
   // by the weight_dtype parameter.
-  rwkv::Model model(input_path, "export-ncnn fp32");
+  Model model(input_path, "export-ncnn fp32");
   model.Run(0);
-  rwkv::ncnnmeta::destroy(model);
+  destroy(model);
+
+  default_dispatch_device() = std::nullopt;
 }
 
 void append_data_to_bin_file(const Tensor &tensor, bool write_tag) {
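The export-ncnn ExportModel above now routes kernel dispatch through a process-wide override: it sets default_dispatch_device() to Device::kNCNNMeta on entry and puts it back to std::nullopt before returning. A minimal sketch of an RAII guard that would make that reset automatic is shown below; the DispatchDeviceGuard name is hypothetical and not part of this patch, and it only assumes the default_dispatch_device() declaration that tensor.h gains further down in this diff.

#include <optional>

namespace rwkv {

enum class Device;                                 // defined in tensor.h
std::optional<Device> &default_dispatch_device();  // added in tensor.cpp below

// Saves the current override on construction and restores it on destruction,
// so an early return or exception cannot leave the global device set.
class DispatchDeviceGuard {
public:
  explicit DispatchDeviceGuard(Device d) : saved_(default_dispatch_device()) {
    default_dispatch_device() = d;
  }
  ~DispatchDeviceGuard() { default_dispatch_device() = saved_; }

private:
  std::optional<Device> saved_;
};

} // namespace rwkv

// Hypothetical usage inside ExportModel:
//   DispatchDeviceGuard guard(Device::kNCNNMeta);
//   Model model(input_path, "export-ncnn fp32");
//   model.Run(0);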
diff --git a/kernels/export-onnx/kernels.cpp b/kernels/export-onnx/kernels.cpp
index 0ec6c54..149d134 100644
--- a/kernels/export-onnx/kernels.cpp
+++ b/kernels/export-onnx/kernels.cpp
@@ -55,16 +55,27 @@ ModelProto destory() {
 
 void ExportModel(const std::string &input_path,
                  const std::string &output_path) {
-
-  // NOTE: fp32 here is just a placeholder. The dtype used by ncnn is determined
-  // by the weight_dtype parameter.
+  default_dispatch_device() = Device::kONNXMeta;
   Model model(input_path, "export-onnx fp32");
   model.Run(0);
+  default_dispatch_device() = std::nullopt;
   ModelProto model_proto = destory();
   // save model_proto to output_path
   std::ofstream ofs(output_path, std::ios::binary);
   RV_CHECK(ofs.good());
   model_proto.SerializeToOstream(&ofs);
+  {
+    std::ofstream config_file(output_path + ".config");
+    config_file << "version: " << model.version() << std::endl;
+    config_file << "head_size: " << model.head_size() << std::endl;
+    config_file << "n_layer: " << model.n_layer() << std::endl;
+    config_file << "n_embd: " << model.n_embd() << std::endl;
+    config_file << "n_att: " << model.n_att() << std::endl;
+    config_file << "n_ffn: " << model.n_ffn() << std::endl;
+    std::string kOnnxImplVersion = "1";
+    config_file << "onnx_impl_version: " << kOnnxImplVersion << std::endl;
+    config_file.close();
+  }
 }
 
 DType _dtype = DType::kFloat32;
@@ -74,10 +85,12 @@ int fr_dtype_to_onnx_dtype(DType fr_dtype) {
     return TensorProto::FLOAT;
   } else if (fr_dtype == DType::kFloat16) {
     return TensorProto::FLOAT16;
+  } else if (fr_dtype == DType::kInt32) {
+    return TensorProto::INT32;
   } else if (fr_dtype == DType::kInt64) {
     return TensorProto::INT64;
   } else {
-    RV_UNIMPLEMENTED();
+    RV_UNIMPLEMENTED() << "Unsupported dtype: " << dtype_to_string(fr_dtype);
   }
 }
 
@@ -118,6 +131,7 @@ Tensor possible_initializer(const Tensor &x) {
   NodeProto *node = new_node();
   node->set_name(std::to_string(unique_op_id()));
   node->set_op_type("Constant");
+  std::cout << "add constant " << output.name << std::endl;
   node->add_output(output.name);
   node->add_attribute()->set_name("value");
   node->mutable_attribute(0)->set_type(onnx::AttributeProto::TENSOR);
@@ -130,6 +144,18 @@ Tensor possible_initializer(const Tensor &x) {
   return output;
 }
 
+Tensor constant(const Tensor &x) {
+  RV_CHECK(x.device() == Device::kCPU);
+  return possible_initializer(x);
+}
+
+Tensor constant(const std::vector<int> &x) {
+  Tensor x_t = Tensor::FromPtr(const_cast<int *>(x.data()),
+                               {static_cast<int64_t>(x.size())}, DType::kInt32,
+                               Device::kCPU);
+  return constant(x_t);
+}
+
 Tensor gather(const Tensor &x, const Tensor &index) {
   RV_CHECK(x.shape().size() == 2);
   RV_CHECK(index.shape().size() == 0);
@@ -152,15 +178,66 @@ Tensor layernorm(const Tensor &x, const Tensor &_weight, const Tensor &_bias) {
   node->add_input(weight.name);
   node->add_input(bias.name);
   node->add_output(output.name);
+  return output;
+}
+
+Tensor concat(const std::vector<Tensor> &xs, int axis) {
+  RV_CHECK(xs.size() > 0);
+  RV_CHECK(axis == 1);
+  std::vector<Shape> x_shapes;
+  for (auto &x : xs) {
+    x_shapes.push_back(x.shape());
+  }
+  auto output = Tensor::Empty(shape::concat(x_shapes, axis), xs[0].dtype(),
+                              xs[0].device());
+  NodeProto *node = new_node();
+  node->set_op_type("Concat");
+  for (auto &x : xs) {
+    node->add_input(x.name);
+  }
+  node->add_output(output.name);
   node->add_attribute()->set_name("axis");
-  node->mutable_attribute(0)->set_i(1);
+  node->mutable_attribute(0)->set_i(axis);
   node->mutable_attribute(0)->set_type(onnx::AttributeProto::INT);
-  node->add_attribute()->set_name("epsilon");
-  node->mutable_attribute(1)->set_f(1e-5f);
-  node->mutable_attribute(1)->set_type(onnx::AttributeProto::FLOAT);
   return output;
 }
 
+Tensor slice(const Tensor &x, const std::vector<int> &starts,
+             const std::vector<int> &ends, const std::vector<int> &axes) {
+  RV_CHECK(axes.size() == starts.size());
+  RV_CHECK(axes.size() == ends.size());
+  auto starts_t = constant(starts);
+  auto ends_t = constant(ends);
+  auto axes_t = constant(axes);
+  auto output = Tensor::Empty(shape::slice(x.shape(), starts, ends, axes),
+                              x.dtype(), x.device());
+  NodeProto *node = new_node();
+  node->set_op_type("Slice");
+  node->add_input(x.name);
+  node->add_input(starts_t.name);
+  node->add_input(ends_t.name);
+  node->add_input(axes_t.name);
+  node->add_output(output.name);
+  return output;
+}
+
+Tensor groupnorm(const Tensor &x, int num_groups, const Tensor &_weight,
+                 const Tensor &_bias) {
+  auto weight = possible_initializer(_weight);
+  auto bias = possible_initializer(_bias);
+  int len = x.shape()[1];
+  RV_CHECK(len % num_groups == 0);
+  int group_size = len / num_groups;
+  std::vector<Tensor> ln_outs;
+  for (int i = 0; i < num_groups; i++) {
+    auto x_slice = slice(x, {i * group_size}, {(i + 1) * group_size}, {1});
+    auto w_slice = slice(weight, {i * group_size}, {(i + 1) * group_size}, {0});
+    auto b_slice = slice(bias, {i * group_size}, {(i + 1) * group_size}, {0});
+    ln_outs.push_back(layernorm(x_slice, w_slice, b_slice));
+  }
+  return concat(ln_outs, 1);
+}
+
 Tensor matmul(const Tensor &_x, const Tensor &_y) {
   auto x = possible_initializer(_x);
   auto y = possible_initializer(_y);
@@ -236,14 +313,14 @@ Tensor sigmoid(const Tensor &x) {
 }
 
 Tensor reshape(const Tensor &x, const Shape &shape) {
+  Tensor shape_cpu_tensor = Tensor::FromPtr(const_cast<int64_t *>(shape.data()),
+                                            {static_cast<int64_t>(shape.size())},
+                                            DType::kInt64, Device::kCPU);
+  auto shape_tensor = constant(shape_cpu_tensor);
   Tensor output = Tensor::Empty(shape, x.dtype(), x.device());
   NodeProto *node = new_node();
   node->set_op_type("Reshape");
   node->add_input(x.name);
-  Tensor shape_cpu_tensor = Tensor::FromPtr(const_cast<int64_t *>(shape.data()),
-                                            {static_cast<int64_t>(shape.size())},
-                                            DType::kInt64, Device::kCPU);
-  auto shape_tensor = possible_initializer(shape_cpu_tensor);
   node->add_input(shape_tensor.name);
   node->add_output(output.name);
   return output;
@@ -254,7 +331,6 @@ Tensor mark_as_output(const Tensor &x, const std::string &name) {
   output->set_name(name);
   output->mutable_type()->mutable_tensor_type()->set_elem_type(
       fr_dtype_to_onnx_dtype(_dtype));
-  std::cout << "output name: " << name << ", dim: " << x.shape() << std::endl;
   for (auto dim : x.shape()) {
     output->mutable_type()
         ->mutable_tensor_type()
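Rather than relying on a dedicated group-normalization operator, the groupnorm helper above decomposes the operation into one Slice plus layernorm per group and a single Concat along axis 1. The standalone sketch below only illustrates the index arithmetic of that decomposition; the 2048/32 sizes are made-up example values, not taken from a real model.

#include <cstdio>

int main() {
  const int len = 2048;      // plays the role of x.shape()[1] in groupnorm
  const int num_groups = 32; // plays the role of the num_groups argument
  // groupnorm checks len % num_groups == 0 before slicing.
  const int group_size = len / num_groups;
  for (int i = 0; i < num_groups; i++) {
    const int start = i * group_size;
    const int end = (i + 1) * group_size;
    // One Slice of x (axis 1) and of weight/bias (axis 0) per group, then a
    // layernorm on the slice; the results are concatenated back on axis 1.
    std::printf("group %2d covers columns [%4d, %4d)\n", i, start, end);
  }
  return 0;
}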
diff --git a/kernels/kernels.h b/kernels/kernels.h
index 8f8c28e..9dfcd64 100644
--- a/kernels/kernels.h
+++ b/kernels/kernels.h
@@ -35,15 +35,15 @@ att_one_v5(const Tensor &x, const Tensor &sx, const Tensor &s,
 
 inline std::tuple<Tensor, Tensor, Tensor>
 att_one_v5_1(const Tensor &x, const Tensor &sx, const Tensor &s,
-             const Tensor &ln_w, const Tensor &ln_b, const Tensor &lx_w,
-             const Tensor &lx_b, const Tensor &k_mix, const Tensor &v_mix,
-             const Tensor &r_mix, const Tensor &g_mix, const Tensor &t_decay, const Tensor &t_first,
-             const Tensor &kw, const Tensor &vw, const Tensor &rw, const Tensor& gw,
-             const Tensor &ow) {
+             const Tensor &ln_w, const Tensor &ln_b, const Tensor &lx_w,
+             const Tensor &lx_b, const Tensor &k_mix, const Tensor &v_mix,
+             const Tensor &r_mix, const Tensor &g_mix, const Tensor &t_decay,
+             const Tensor &t_first, const Tensor &kw, const Tensor &vw,
+             const Tensor &rw, const Tensor &gw, const Tensor &ow) {
   auto tmp = KernelRegistry::Instance().Get(
       "att_one_v5_1", x.device());
-  return tmp(x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay,
-             t_first, kw, vw, rw, gw, ow);
+  return tmp(x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix,
+             t_decay, t_first, kw, vw, rw, gw, ow);
 }
 
 // def cuda_ffn_one_fp16(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw,
@@ -95,12 +95,12 @@ inline Tensor matmul(const Tensor &a, const Tensor &b) {
 inline Tensor add(const Tensor &x, const Tensor &y) {
   // TODO: global device
   return KernelRegistry::Instance().Get(
-      "add", Device::kNCNNMeta)(x, y);
+      "add", default_dispatch_device().value_or(x.device()))(x, y);
 }
 
 inline Tensor sub(float x, const Tensor &y) {
   return KernelRegistry::Instance().Get(
-      "rsub_scalar", Device::kNCNNMeta)(x, y);
+      "rsub_scalar", default_dispatch_device().value_or(y.device()))(x, y);
 }
 
 inline Tensor sub(const Tensor &x, const Tensor &y) {
@@ -109,8 +109,8 @@ inline Tensor sub(const Tensor &x, const Tensor &y) {
 }
 
 inline Tensor mul(const Tensor &x, const Tensor &y) {
-  return KernelRegistry::Instance().Get("mul", x.device())(x,
-                                                           y);
+  return KernelRegistry::Instance().Get(
+      "mul", default_dispatch_device().value_or(x.device()))(x, y);
 }
 
 inline Tensor div(const Tensor &x, const Tensor &y) {
@@ -138,18 +138,16 @@ inline Tensor maximum(const Tensor &x, const Tensor &y) {
 }
 
 inline Tensor softmax(const Tensor &x, float temperature) {
-  return KernelRegistry::Instance().Get("softmax",
-                                        x.device())(x, temperature);
+  return KernelRegistry::Instance().Get(
+      "softmax", x.device())(x, temperature);
 }
 
 inline Tensor reshape(const Tensor &x, const Shape &shape) {
-  return KernelRegistry::Instance().Get("reshape",
-                                        x.device())(x, shape);
+  return KernelRegistry::Instance().Get(
+      "reshape", x.device())(x, shape);
 }
 
-inline Tensor flatten(const Tensor &x) {
-  return reshape(x, {x.numel()});
-}
+inline Tensor flatten(const Tensor &x) { return reshape(x, {x.numel()}); }
 
 inline Tensor unsqueeze(const Tensor &x, int dim) {
   auto new_shape = x.shape();
@@ -165,10 +163,9 @@ inline Tensor mark_as_output(const Tensor &x, const std::string &name) {
 class Model;
 
 inline void init_model(Model *model, Device device, const std::string &path,
-                       const std::string &strategy, const std::any& extra) {
-  KernelRegistry::Instance()
-      .Get(
-          "init_model", device)(model, device, path, strategy, extra);
+                       const std::string &strategy, const std::any &extra) {
+  KernelRegistry::Instance().Get("init_model", device)(
+      model, device, path, strategy, extra);
 }
 
 inline Tensor ModelForward(const Model *model, Device device, int id,
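The element-wise wrappers above now choose the dispatch device with default_dispatch_device().value_or(x.device()) instead of hard-coding Device::kNCNNMeta, so normal inference follows the tensor's own device while the exporters simply set the override. The self-contained snippet below only illustrates that value_or fallback; its local Device enum and default_dispatch_device() are stand-ins mimicking the declarations added to tensor.h and tensor.cpp later in this diff.

#include <cassert>
#include <optional>

// Stand-ins for the real declarations in tensor.h / tensor.cpp.
enum class Device { kCPU, kONNXMeta };

std::optional<Device> &default_dispatch_device() {
  static std::optional<Device> device = std::nullopt;
  return device;
}

int main() {
  const Device tensor_device = Device::kCPU;
  // No override set: dispatch follows the tensor's own device.
  assert(default_dispatch_device().value_or(tensor_device) == Device::kCPU);
  // Export mode: the override wins regardless of where the tensor lives.
  default_dispatch_device() = Device::kONNXMeta;
  assert(default_dispatch_device().value_or(tensor_device) == Device::kONNXMeta);
  default_dispatch_device() = std::nullopt;
  return 0;
}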
diff --git a/kernels/shape/shape_inference.cpp b/kernels/shape/shape_inference.cpp
index bca3448..d5c11a9 100644
--- a/kernels/shape/shape_inference.cpp
+++ b/kernels/shape/shape_inference.cpp
@@ -46,7 +46,8 @@ Shape matmul(const Shape &x, const Shape &y) {
 Shape broadcast_binary(const Shape &x, const Shape &y) {
   auto nrank = std::max(x.size(), y.size());
   Shape output_shape(nrank);
-  for (int i = nrank - 1, x_idx = x.size() - 1, y_idx = y.size() - 1; i >= 0; i--, x_idx--, y_idx--) {
+  for (int i = nrank - 1, x_idx = x.size() - 1, y_idx = y.size() - 1; i >= 0;
+       i--, x_idx--, y_idx--) {
     if (x_idx < 0) {
       output_shape[i] = y[y_idx];
     } else if (y_idx < 0) {
@@ -63,5 +64,37 @@ Shape broadcast_binary(const Shape &x, const Shape &y) {
   }
   return output_shape;
 }
+
+Shape concat(const std::vector<Shape> &xs, int axis) {
+  RV_CHECK(xs.size() > 0);
+  RV_CHECK(axis >= 0 && axis < xs[0].size());
+  Shape output_shape = xs[0];
+  for (int i = 1; i < xs.size(); i++) {
+    RV_CHECK(xs[i].size() == output_shape.size());
+    for (int j = 0; j < xs[i].size(); j++) {
+      if (j == axis) {
+        output_shape[j] += xs[i][j];
+      } else {
+        RV_CHECK(xs[i][j] == output_shape[j]);
+      }
+    }
+  }
+  return output_shape;
+}
+
+Shape slice(const Shape &x, const std::vector<int> &starts,
+            const std::vector<int> &ends, const std::vector<int> &axes) {
+  RV_CHECK(starts.size() == ends.size());
+  RV_CHECK(starts.size() == axes.size());
+  Shape output_shape = x;
+  for (int i = 0; i < starts.size(); i++) {
+    RV_CHECK(starts[i] >= 0 && starts[i] < x[axes[i]]);
+    RV_CHECK(ends[i] >= 0 && ends[i] <= x[axes[i]]);
+    RV_CHECK(starts[i] < ends[i]);
+    output_shape[axes[i]] = ends[i] - starts[i];
+  }
+  return output_shape;
+}
+
 } // namespace shape
 } // namespace rwkv
diff --git a/kernels/shape/shape_inference.h b/kernels/shape/shape_inference.h
index e35b342..c5ba6fb 100644
--- a/kernels/shape/shape_inference.h
+++ b/kernels/shape/shape_inference.h
@@ -6,5 +6,10 @@ Shape matmul(const Shape &x, const Shape &y);
 
 Shape broadcast_binary(const Shape &x, const Shape &y);
 
+Shape concat(const std::vector<Shape> &xs, int axis);
+
+Shape slice(const Shape &x, const std::vector<int> &starts,
+            const std::vector<int> &ends, const std::vector<int> &axes);
+
 } // namespace shape
 } // namespace rwkv
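As a worked example of the two new shape rules: slicing a (1, 2048) shape with starts={0}, ends={64}, axes={1} keeps every axis except axis 1, whose extent becomes 64 - 0 = 64, and concatenating two (1, 64) shapes along axis 1 sums that extent to give (1, 128). The check below is a hypothetical standalone test, assuming Shape behaves like std::vector<int64_t> and that the file is linked against shape_inference.cpp; it is not part of this patch.

#include <cassert>

#include "kernels/shape/shape_inference.h"

int main() {
  using rwkv::Shape;
  const Shape x{1, 2048};
  // slice: only the sliced axis changes, to ends[i] - starts[i].
  const Shape piece =
      rwkv::shape::slice(x, /*starts=*/{0}, /*ends=*/{64}, /*axes=*/{1});
  assert(piece == (Shape{1, 64}));
  // concat: extents are summed along the concat axis, all others must match.
  const Shape both = rwkv::shape::concat({piece, piece}, /*axis=*/1);
  assert(both == (Shape{1, 128}));
  return 0;
}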
diff --git a/tensor.cpp b/tensor.cpp
index 76d1982..d75ebc6 100644
--- a/tensor.cpp
+++ b/tensor.cpp
@@ -12,6 +12,11 @@
 
 namespace rwkv {
 
+std::optional<Device> &default_dispatch_device() {
+  static std::optional<Device> _default_dispatch_device = std::nullopt;
+  return _default_dispatch_device;
+}
+
 // operator<< for Shape
 std::ostream &operator<<(std::ostream &os, const Shape &shape) {
   os << "(";
diff --git a/tensor.h b/tensor.h
index a183c34..7ac1d31 100644
--- a/tensor.h
+++ b/tensor.h
@@ -25,6 +25,7 @@ enum class DType {
   kInt8,
   kFloat16,
   kFloat32,
+  kInt32,
   kInt64,
 };
 using float16 = half_float::half;
@@ -36,6 +37,8 @@ enum class Device {
   kNCNN,
   kONNX,
 };
+std::optional<Device> &default_dispatch_device();
+
 template <typename T> inline const DType dtype_v = DType::kFloat32;
 template <> inline const DType dtype_v<float16> = DType::kFloat16;
 #ifdef FR_ENABLE_CUDA
@@ -54,6 +57,10 @@ inline std::string dtype_to_string(DType dtype) {
     return "fp32";
   } else if (dtype == DType::kFloat16) {
     return "fp16";
+  } else if (dtype == DType::kInt32) {
+    return "int32";
+  } else if (dtype == DType::kInt64) {
+    return "int64";
   } else if (dtype == DType::kInt8) {
     return "int8";
   } else if (dtype == DType::kInt4) {
@@ -79,10 +86,12 @@ inline int32_t elem_size(DType dtype) {
     return 2;
   case DType::kFloat32:
     return 4;
+  case DType::kInt32:
+    return 4;
   case DType::kInt64:
     return 8;
   default:
-    RV_UNIMPLEMENTED();
+    RV_UNIMPLEMENTED() << "Unsupported dtype: " << dtype_to_string(dtype);
   }
 }
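For reference, the <output path>.config file written by ExportModel in kernels/export-onnx/kernels.cpp above is a plain text file of "key: value" lines (version, head_size, n_layer, n_embd, n_att, n_ffn, onnx_impl_version). A minimal sketch of a reader for that format follows; read_onnx_config is a hypothetical helper, not something this patch adds.

#include <fstream>
#include <map>
#include <string>

// Parses the exporter's .config file into key -> value strings,
// e.g. config["n_layer"] or config["onnx_impl_version"].
std::map<std::string, std::string> read_onnx_config(const std::string &path) {
  std::map<std::string, std::string> config;
  std::ifstream ifs(path);
  std::string line;
  while (std::getline(ifs, line)) {
    const auto pos = line.find(": ");
    if (pos == std::string::npos) {
      continue;
    }
    config[line.substr(0, pos)] = line.substr(pos + 2);
  }
  return config;
}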