Skip to content

Commit

Permalink
test mha oom
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jun 19, 2024
1 parent cdb2f71 commit d2a6b6e
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tests/test_multiheadattention_oom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_multiheadattention_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
{
const int qdim = q.w;
const int kdim = k.w;
const int vdim = v.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
pd.set(2, embed_dim * qdim);
pd.set(3, kdim);
pd.set(4, vdim);
pd.set(5, attn_mask);

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * kdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * vdim);
weights[5] = RandomMat(embed_dim);
weights[6] = RandomMat(qdim * embed_dim);
weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(3);
as[0] = q;
as[1] = k;
as[2] = v;

if (attn_mask)
{
as.push_back(RandomMat(k.h, q.h));
}

int ret = test_layer_oom("MultiHeadAttention", pd, weights, as, 1);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_oom failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
}

return ret;
}

static int test_multiheadattention_0()
{
return 0
|| test_multiheadattention_oom(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
|| test_multiheadattention_oom(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
|| test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
|| test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
}

int main()
{
SRAND(7767517);

return test_multiheadattention_0();
}

0 comments on commit d2a6b6e

Please sign in to comment.