Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect GEMM output when A or B has dims == 1 #4884

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions src/layer/arm/gemm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
{
const int elempack = A.elempack;
const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w;

float* pp = AT;


int ii = 0;
#if __ARM_NEON
#if __aarch64__
Expand Down Expand Up @@ -3785,9 +3785,9 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c

static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -3840,8 +3840,8 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -3899,7 +3899,7 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int

static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -3994,7 +3994,7 @@ static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob,

static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -4018,8 +4018,8 @@ static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob,
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -4329,20 +4329,20 @@ int Gemm_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down Expand Up @@ -4502,9 +4502,9 @@ int Gemm_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
#if NCNN_BF16
static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -4557,8 +4557,8 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -4617,7 +4617,7 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo

static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -4713,7 +4713,7 @@ static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top

static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -4737,8 +4737,8 @@ static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -5001,20 +5001,20 @@ int Gemm_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
26 changes: 13 additions & 13 deletions src/layer/arm/gemm_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,9 +2344,9 @@ static void get_optimal_tile_mnk_fp16sa(int M, int N, int K, int constant_TILE_M

static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -2399,8 +2399,8 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -2458,7 +2458,7 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl

static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -2553,7 +2553,7 @@ static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& to

static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -2577,8 +2577,8 @@ static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& to
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -2835,20 +2835,20 @@ int Gemm_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
26 changes: 13 additions & 13 deletions src/layer/arm/gemm_arm_vfpv4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ extern void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max

static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -86,8 +86,8 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -146,7 +146,7 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo

static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -242,7 +242,7 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top

static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -266,8 +266,8 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -530,20 +530,20 @@ int Gemm_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
Loading
Loading