diff --git a/lite/kernels/xpu/__xpu__bigru_compute.cc b/lite/kernels/xpu/__xpu__bigru_compute.cc index d0122e778d5..08de0d70c45 100644 --- a/lite/kernels/xpu/__xpu__bigru_compute.cc +++ b/lite/kernels/xpu/__xpu__bigru_compute.cc @@ -195,24 +195,26 @@ void XPUBiGRUCompute::MulRun(bool forward) { (bias_guard == nullptr) ? nullptr : reinterpret_cast(bias_guard->addr_); - - int r = xdnn::fc_int16( - ctx.GetRawContext(), /* context */ - false, /* TransA */ - true, /* TransB */ - m, - n, - k, - 1.0f, /* alpha */ - x_matrix.data(), /* A */ - reinterpret_cast(input_max_guard_->addr_), - reinterpret_cast(quant_weight_guard->addr_), /* B */ - reinterpret_cast(weight_max_guard->addr_), - 0.0f, /* beta */ - output.mutable_data(TARGET(kXPU)), /* C */ - reinterpret_cast(mul_output_max_guard_->addr_), - bias_ptr, - xdnn::Activation_t::LINEAR); + int r = xdnn::fc_fusion( + ctx.GetRawContext(), // ctx + x_matrix.data(), // x + reinterpret_cast(quant_weight_guard->addr_), // w + output.mutable_data(TARGET(kXPU)), // y + m, // m + n, // n + k, // k + false, // x_trans + true, // w_trans + reinterpret_cast(input_max_guard_->addr_), // x_maxptr + reinterpret_cast(weight_max_guard->addr_), // w_maxptr + reinterpret_cast(mul_output_max_guard_->addr_), // y_maxptr, + k, // ldx + k, // ldw + n, // ldy + 1.0f, // alpha + 0.0f, // beta + bias_ptr, // bias + xdnn::Activation_t::LINEAR); // act CHECK_EQ(r, 0); *(output.mutable_lod()) = origin_x.lod(); } diff --git a/lite/kernels/xpu/__xpu__fc_compute.cc b/lite/kernels/xpu/__xpu__fc_compute.cc index a70f1dfc587..5574ea36a98 100644 --- a/lite/kernels/xpu/__xpu__fc_compute.cc +++ b/lite/kernels/xpu/__xpu__fc_compute.cc @@ -127,25 +127,29 @@ void XPUFcCompute::Run() { reinterpret_cast(input_max_guard_->addr_)); CHECK_EQ(r, 0); } - r = xdnn::fc_int16( - ctx.GetRawContext(), /* context */ - false, /* TransA */ - true, /* TransB */ - m, /* m */ - n, /* n */ - k, /* k */ - 1.0f, /* alpha */ - param.input->data(), /* A */ + r = xdnn::fc_fusion( + ctx.GetRawContext(), // ctx + param.input->data(), // x + reinterpret_cast(quant_weight_guard_->addr_), // w + param.output->mutable_data(TARGET(kXPU)), // y + m, // m + n, // n + k, // k + false, // x_trans + true, // w_trans (input_max == nullptr) ? reinterpret_cast(input_max_guard_->addr_) - : input_max, - reinterpret_cast(quant_weight_guard_->addr_), /* B */ - reinterpret_cast(weight_max_guard_->addr_), - 0.0f, /* beta */ - param.output->mutable_data(TARGET(kXPU)), /* C */ - output_max, - bias, /* bias */ - act /* act_type */); + : input_max, // x_maxptr + reinterpret_cast(weight_max_guard_->addr_), // w_maxptr + output_max, // y_maxptr + k, // ldx + k, // ldw + n, // ldy + 1.0f, // alpha + 0.0f, // beta + bias, // bias + act); // act + CHECK_EQ(r, 0); } else if (param.precision == "int8") { bool x_trans = false; diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc index edf5766678b..124968c8424 100644 --- a/lite/kernels/xpu/__xpu__mmdnn_compute.cc +++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc @@ -224,23 +224,26 @@ class MMDNNFcOp { CHECK_EQ(r, 0); in_max_by_caller = in_max_; } - - r = xdnn::fc_int16(ctx, - false, - true, - m, - n_, - k_, - 1.0f, - in, - in_max_by_caller, - weight_, - weight_max_, - 0.0f, - out, - out_max, - bias_, - act_type_); + r = xdnn::fc_fusion( + ctx, // ctx + in, // x + weight_, // w + out, // y + m, // m + n_, // n + k_, // k + false, // x_trans + true, // w_trans + in_max_by_caller, // x_maxptr + weight_max_, // w_maxptr + out_max, // y_maxptr + k_, // ldx + k_, // ldw + n_, // ldy + 1.0f, // alpha + 0.0f, // beta + bias_, // bias + act_type_); // act CHECK_EQ(r, 0); } }; diff --git a/lite/kernels/xpu/matmul_compute.cc b/lite/kernels/xpu/matmul_compute.cc index 5399a71298b..1ef45cf626d 100644 --- a/lite/kernels/xpu/matmul_compute.cc +++ b/lite/kernels/xpu/matmul_compute.cc @@ -66,17 +66,27 @@ void MatMulCompute::Run() { int r = 0; if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) { - r = xdnn::fc_int16(ctx.GetRawContext(), /* context */ - mat_dim_a.trans_, /* TransA */ - mat_dim_b.trans_, /* TransB */ - mat_dim_a.height_, /* m */ - mat_dim_b.width_, /* n */ - mat_dim_a.width_, /* k */ - param.alpha, /* alpha */ - x->data(), /* A */ - y->data(), /* B */ - 0.0f, /* beta */ - out->mutable_data(TARGET(kXPU)) /* C */); + r = xdnn::fc_fusion( + ctx.GetRawContext(), // ctx + x->data(), // x + y->data(), // w + out->mutable_data(TARGET(kXPU)), // y + mat_dim_a.height_, // m + mat_dim_b.width_, // n + mat_dim_a.width_, // k + mat_dim_a.trans_, // x_trans + mat_dim_b.trans_, // w_trans + nullptr, // x_maxptr + nullptr, // w_maxptr + nullptr, // y_maxptr + lda, // ldx + ldb, // ldw + ldc, // ldy + param.alpha, // alpha + 0.0f, // beta + nullptr, // bias + xdnn::Activation_t::LINEAR); // act + } else { // batch matmul r = xdnn::gemm_strided_batched_int16( diff --git a/lite/kernels/xpu/mul_compute.cc b/lite/kernels/xpu/mul_compute.cc index 8aa93a9c8b8..188129adf27 100644 --- a/lite/kernels/xpu/mul_compute.cc +++ b/lite/kernels/xpu/mul_compute.cc @@ -44,18 +44,27 @@ void MulCompute::Run() { int k = x_matrix.dims()[1]; int n = y_matrix.dims()[1]; - int r = - xdnn::fc_int16(ctx.GetRawContext(), /* context */ - false, /* TransA */ - false, /* TransB */ - m, - n, - k, - 1.0f, /* alpha */ - x_matrix.data(), /* A */ - y_matrix.data(), /* B */ - 0.0f, /* beta */ - param.output->mutable_data(TARGET(kXPU)) /* C */); + int r = xdnn::fc_fusion( + ctx.GetRawContext(), // ctx + x_matrix.data(), // x + y_matrix.data(), // w + param.output->mutable_data(TARGET(kXPU)), // y + m, // m + n, // n + k, // k + false, // x_trans + false, // w_trans + nullptr, // x_maxptr + nullptr, // w_maxptr + nullptr, // y_maxptr + k, // ldx + n, // ldw + n, // ldy + 1.0f, // alpha + 0.0f, // beta + nullptr, // bias + xdnn::Activation_t::LINEAR); // act + CHECK_EQ(r, 0); }