Skip to content

Commit

Permalink
fix gemm vulkan without C
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 10, 2024
1 parent 9c3057b commit 0b7755d
Showing 1 changed file with 94 additions and 50 deletions.
144 changes: 94 additions & 50 deletions src/layer/vulkan/gemm_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,56 +182,78 @@ int Gemm_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
{
const VkMat& A0 = constantA ? A_data_gpu : bottom_blobs[0];
const VkMat& B0 = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1];
const VkMat& C0 = constantC ? C_data_gpu : bottom_blobs[bottom_blobs.size() - 1];

VkMat A;
VkMat B;
VkMat C;
vkdev->convert_packing(A0, A, 1, cmd, opt);
vkdev->convert_packing(B0, B, 1, cmd, opt);
vkdev->convert_packing(C0, C, 1, cmd, opt);

const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;

int broadcast_type_C;
VkMat C;
int broadcast_type_C = -1;
if (constantC)
{
vkdev->convert_packing(C_data_gpu, C, 1, cmd, opt);
broadcast_type_C = constant_broadcast_type_C;
}
else
{
if (C.dims == 1 && C.w == 1)
{
// scalar
broadcast_type_C = 0;
}
if (C.dims == 1 && C.w == M)
VkMat C0;
if (constantA && constantB)
{
// M
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkMat();
}
if (C.dims == 1 && C.w == N)
else if (constantA)
{
// N
broadcast_type_C = 4;
C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat();
}
if (C.dims == 2 && C.w == 1 && C.h == M)
else if (constantB)
{
// Mx1
broadcast_type_C = 2;
C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat();
}
if (C.dims == 2 && C.w == N && C.h == M)
else
{
// MxN
broadcast_type_C = 3;
C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkMat();
}
if (C.dims == 2 && C.w == N && C.h == 1)

if (!C0.empty())
{
// 1xN
broadcast_type_C = 4;
vkdev->convert_packing(C0, C, 1, cmd, opt);

if (C0.dims == 1 && C0.w == 1)
{
// scalar
broadcast_type_C = 0;
}
if (C0.dims == 1 && C0.w == M)
{
// M
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
}
if (C0.dims == 1 && C0.w == N)
{
// N
broadcast_type_C = 4;
}
if (C0.dims == 2 && C0.w == 1 && C0.h == M)
{
// Mx1
broadcast_type_C = 2;
}
if (C0.dims == 2 && C0.w == N && C0.h == M)
{
// MxN
broadcast_type_C = 3;
}
if (C0.dims == 2 && C0.w == N && C0.h == 1)
{
// 1xN
broadcast_type_C = 4;
}
}
}

Expand Down Expand Up @@ -314,56 +336,78 @@ int Gemm_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vecto
{
const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0];
const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1];
const VkImageMat& C0 = constantC ? C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1];

VkImageMat A;
VkImageMat B;
VkImageMat C;
vkdev->convert_packing(A0, A, 1, cmd, opt);
vkdev->convert_packing(B0, B, 1, cmd, opt);
vkdev->convert_packing(C0, C, 1, cmd, opt);

const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;

int broadcast_type_C;
VkImageMat C;
int broadcast_type_C = -1;
if (constantC)
{
vkdev->convert_packing(C_data_gpu_image, C, 1, cmd, opt);
broadcast_type_C = constant_broadcast_type_C;
}
else
{
if (C.dims == 1 && C.w == 1)
{
// scalar
broadcast_type_C = 0;
}
if (C.dims == 1 && C.w == M)
VkImageMat C0;
if (constantA && constantB)
{
// M
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkImageMat();
}
if (C.dims == 1 && C.w == N)
else if (constantA)
{
// N
broadcast_type_C = 4;
C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat();
}
if (C.dims == 2 && C.w == 1 && C.h == M)
else if (constantB)
{
// Mx1
broadcast_type_C = 2;
C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat();
}
if (C.dims == 2 && C.w == N && C.h == M)
else
{
// MxN
broadcast_type_C = 3;
C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkImageMat();
}
if (C.dims == 2 && C.w == N && C.h == 1)

if (!C0.empty())
{
// 1xN
broadcast_type_C = 4;
vkdev->convert_packing(C0, C, 1, cmd, opt);

if (C.dims == 1 && C.w == 1)
{
// scalar
broadcast_type_C = 0;
}
if (C.dims == 1 && C.w == M)
{
// M
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
}
if (C.dims == 1 && C.w == N)
{
// N
broadcast_type_C = 4;
}
if (C.dims == 2 && C.w == 1 && C.h == M)
{
// Mx1
broadcast_type_C = 2;
}
if (C.dims == 2 && C.w == N && C.h == M)
{
// MxN
broadcast_type_C = 3;
}
if (C.dims == 2 && C.w == N && C.h == 1)
{
// 1xN
broadcast_type_C = 4;
}
}
}

Expand Down

0 comments on commit 0b7755d

Please sign in to comment.