fix onnx fp16 bug
Signed-off-by: daquexian <daquexian566@gmail.com>
daquexian committed Oct 13, 2023
1 parent 81d90f4 commit 7ae8645
Showing 5 changed files with 42 additions and 19 deletions.
8 changes: 4 additions & 4 deletions export_onnx.cpp
@@ -4,20 +4,20 @@
 #include <kernels/export-onnx/kernels.h>
 
 int main(int argc, char **argv) {
-  if (argc != 3) {
+  if (argc != 4) {
     std::cerr
-        << "Usage: " << argv[0] << " <input path> <output prefix>"
+        << "Usage: " << argv[0] << " <input path> <output prefix> <dtype>"
         << std::endl;
     return 1;
   }
   if (std::ifstream ifs(argv[1]); !ifs.good()) {
     std::cerr << "Failed to open " << argv[1] << std::endl;
     std::cerr
-        << "Usage: " << argv[0] << " <input path> <output prefix>"
+        << "Usage: " << argv[0] << " <input path> <output prefix> <dtype>"
         << std::endl;
     return 1;
   }
-  rwkv::onnxmeta::ExportModel(argv[1], argv[2]);
+  rwkv::onnxmeta::ExportModel(argv[1], argv[2], argv[3]);
   return 0;
 }

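With the extra <dtype> argument, the exporter is invoked as, for example, ./export_onnx model.fr out fp16 (file names here are hypothetical; "fp16" is the value the old code hardcoded). A minimal programmatic sketch of the same call, assuming only the three-argument ExportModel signature introduced by this commit:

#include <kernels/export-onnx/kernels.h>

int main() {
  // Hypothetical input path and output prefix; the third argument is the
  // dtype string that ExportModel forwards into its "export-onnx " + dtype
  // mode string (see kernels/export-onnx/kernels.cpp below).
  rwkv::onnxmeta::ExportModel("model.fr", "out", "fp16");
  return 0;
}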
34 changes: 27 additions & 7 deletions kernels/default/model_forward.cpp
@@ -32,14 +32,34 @@ Tensor ModelForward(Model *model, Device device, int id) {
                           model->_embd_weights[0].shape()[0]},
                          model->weight_dtype(), Device::kCPU);
   {
-    float16 *ptr = embd_weights_cpu.data_ptr<float16>();
-    for (int i = 0; i < model->_embd_weights.size(); i++) {
-      for (int j = 0; j < model->_n_embd; j++) {
-        // embd weights in .fr are always fp16
-        // *ptr++ = static_cast<float>(
-        //     model->_embd_weights[i].data_ptr<float16>()[j]);
-        *ptr++ = model->_embd_weights[i].data_ptr<float16>()[j];
+    auto fr_embd_dtype = model->_embd_weights[0].dtype();
+    auto weight_dtype = model->weight_dtype();
+    if (fr_embd_dtype == DType::kFloat16 &&
+        weight_dtype == DType::kFloat32) {
+      auto *ptr = embd_weights_cpu.data_ptr<float>();
+      for (int i = 0; i < model->_embd_weights.size(); i++) {
+        for (int j = 0; j < model->_n_embd; j++) {
+          *ptr++ = model->_embd_weights[i].data_ptr<float16>()[j];
+        }
+      }
+    } else if (fr_embd_dtype == DType::kFloat32 &&
+               weight_dtype == DType::kFloat32) {
+      auto *ptr = embd_weights_cpu.data_ptr<float>();
+      for (int i = 0; i < model->_embd_weights.size(); i++) {
+        for (int j = 0; j < model->_n_embd; j++) {
+          *ptr++ = model->_embd_weights[i].data_ptr<float>()[j];
+        }
+      }
+    } else if (fr_embd_dtype == DType::kFloat16 &&
+               weight_dtype == DType::kFloat16) {
+      auto *ptr = embd_weights_cpu.data_ptr<float16>();
+      for (int i = 0; i < model->_embd_weights.size(); i++) {
+        for (int j = 0; j < model->_n_embd; j++) {
+          *ptr++ = model->_embd_weights[i].data_ptr<float16>()[j];
+        }
       }
+    } else {
+      RV_UNIMPLEMENTED();
     }
   }
   if (model->_act_device == Device::kNCNNMeta) {
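The substance of the fix is the first branch: when the .fr file stores fp16 embeddings but the model was built with fp32 weights, each element has to be widened while it is copied instead of being written into the fp32 buffer as raw fp16. A minimal, self-contained sketch of that widening (half_to_float and widen_embedding are illustrative helpers, not part of the repository; the repo's float16 type is assumed to perform the equivalent conversion implicitly):

#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Widen one IEEE 754 binary16 value to float.
static float half_to_float(uint16_t h) {
  const uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  const uint32_t exp = (h >> 10) & 0x1Fu;
  const uint32_t mant = h & 0x3FFu;
  if (exp == 0) {
    // zero or subnormal: value is +/- mant * 2^-24, exact in float
    float f = std::ldexp(static_cast<float>(mant), -24);
    return (h & 0x8000u) ? -f : f;
  }
  // normal (rebias exponent 15 -> 127) or inf/NaN (exp == 31 -> 255)
  const uint32_t bits =
      sign | ((exp == 0x1Fu ? 0xFFu : exp + 112u) << 23) | (mant << 13);
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Element-wise widening copy of an fp16 table into an fp32 buffer, which is
// what the new fp16-file / fp32-weights branch does value by value.
std::vector<float> widen_embedding(const std::vector<uint16_t> &fp16_values) {
  std::vector<float> fp32_values;
  fp32_values.reserve(fp16_values.size());
  for (uint16_t h : fp16_values) fp32_values.push_back(half_to_float(h));
  return fp32_values;
}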
12 changes: 6 additions & 6 deletions kernels/export-onnx/kernels.cpp
@@ -18,7 +18,7 @@ using onnx::NodeProto;
 using onnx::TensorProto;
 using onnx::ValueInfoProto;
 
-static const int kOpsetVersion = 18;
+static const int kOpsetVersion = 17;
 static const int kExternalDataThreshold = 1024;
 // we do not use opset 17 layernorm by default (even if it is available)
 // because it is not supported by NNAPI, CoreML, etc.
@@ -62,14 +62,14 @@ ModelProto Finish() {
 }
 
 void ExportModel(const std::string &input_path,
-                 const std::string &output_path) {
+                 const std::string &output_path, const std::string& dtype) {
 
   default_dispatch_device() = Device::kONNXMeta;
   external_data_filename = output_path + ".bin";
   external_data_file.open(external_data_filename, std::ios::binary);
   RV_CHECK(external_data_file.good());
   external_data_offset = 0;
-  Model model(input_path, "export-onnx fp16");
+  Model model(input_path, "export-onnx " + dtype);
   model.Run(0);
   default_dispatch_device() = std::nullopt;
   ModelProto model_proto = Finish();
@@ -312,8 +312,6 @@ Tensor matmul(const Tensor &_x, const Tensor &_y) {
   return output;
 }
 
-// TODO: add shape inference
-
 #define BROADCAST_BINARY_OP(op, onnx_type)                                    \
   Tensor op(const Tensor &_x, const Tensor &_y) {                             \
     auto x = possible_initializer(_x);                                        \
@@ -336,7 +334,9 @@ BROADCAST_BINARY_OP(maximum, "Max")
 
 Tensor scalar_div(Tensor &x, float y) {
   Tensor y_t = constant_scalar(y, x.dtype());
-  return div(x, y_t);
+  auto ret = div(x, y_t);
+  x = ret;
+  return ret;
 }
 
 Tensor rsub_scalar(float x, const Tensor &_y) {
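The scalar_div change writes the quotient back into x, matching what the non-const Tensor &x parameter suggests callers expect of it. A reduced, self-contained analogue of the fixed behaviour (plain float stands in for the exporter's Tensor type, and the helper name is illustrative):

#include <cassert>

// A helper that takes its operand by non-const reference is expected to
// leave the result in that operand, not only return it.
static float scalar_div_fixed(float &x, float y) {
  float ret = x / y;
  x = ret;  // write back, mirroring the "x = ret;" added in this commit
  return ret;
}

int main() {
  float x = 6.0f;
  float r = scalar_div_fixed(x, 2.0f);
  assert(r == 3.0f);
  assert(x == 3.0f);  // without the write-back, x would still hold 6.0f
  return 0;
}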
5 changes: 3 additions & 2 deletions kernels/export-onnx/kernels.h
@@ -4,8 +4,9 @@ namespace rwkv {
 namespace onnxmeta {
 Tensor add_input(const Shape &shape, DType dtype, const std::string &name);
 Tensor possible_initializer(const Tensor &x);
-Tensor gather(const Tensor& x, const Tensor& index);
+Tensor gather(const Tensor &x, const Tensor &index);
 
-void ExportModel(const std::string &input_path, const std::string &output_path);
+void ExportModel(const std::string &input_path, const std::string &output_path,
+                 const std::string &dtype);
 } // namespace onnxmeta
 } // namespace rwkv
2 changes: 2 additions & 0 deletions kernels/onnx/init_model.cpp
@@ -29,6 +29,8 @@ void init_model(Model *model, Device device, const std::string &path,
   } else {
     session_options.SetLogSeverityLevel(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR);
   }
+  // ORT optimization has a bug on rwkv models with layernorm 17
+  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
 #ifdef __ANDROID__
   if (std::getenv("NNAPI") != nullptr) {
     uint32_t nnapi_flags = 0;
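Disabling graph optimization is a standard ONNX Runtime session option. A minimal stand-alone sketch of loading an exported model that way (the model path is hypothetical; the real init_model additionally configures logging and NNAPI as shown above):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "rwkv");
  Ort::SessionOptions session_options;
  // Work around the ORT optimizer bug with LayerNormalization-17 graphs by
  // turning graph optimization off entirely for this session.
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
  // Hypothetical path to a model produced by export_onnx
  // (the path is an ORTCHAR_T string, i.e. a wide string on Windows).
  Ort::Session session(env, "model.onnx", session_options);
  return 0;
}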
