forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_pool3d_fixture.h
163 lines (152 loc) · 7.07 KB
/
test_pool3d_fixture.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "test_pool3d.h"
//
// Short Execute() test helper to register each test separately by all parameters.
//
// One instance of this fixture represents exactly one 3D pooling test case:
// the full parameter set is captured by the constructor and replayed against
// the shared tester in TestBody(). Instances are created lazily by the factory
// handed to testing::RegisterTest() in RegisterSingleTest().
//
// PoolingKind and Threaded are forwarded unchanged to MlasPool3DTest.
template <MLAS_POOLING_KIND PoolingKind, bool Threaded>
class Pooling3dShortExecuteTest : public MlasTestFixture<MlasPool3DTest<PoolingKind, Threaded>> {
 private:
  // Shorthand for the (template-dependent) fixture base class.
  using FixtureBase = MlasTestFixture<MlasPool3DTest<PoolingKind, Threaded>>;

 public:
  // Stores every pooling parameter verbatim; nothing is validated or executed
  // until TestBody() runs.
  explicit Pooling3dShortExecuteTest(size_t BatchCount,
                                     size_t InputChannels,
                                     size_t InputDepth,
                                     size_t InputHeight,
                                     size_t InputWidth,
                                     size_t KernelDepth,
                                     size_t KernelHeight,
                                     size_t KernelWidth,
                                     size_t PaddingLeftDepth,
                                     size_t PaddingLeftHeight,
                                     size_t PaddingLeftWidth,
                                     size_t PaddingRightDepth,
                                     size_t PaddingRightHeight,
                                     size_t PaddingRightWidth,
                                     size_t StrideDepth,
                                     size_t StrideHeight,
                                     size_t StrideWidth)
      : BatchCount_(BatchCount),
        InputChannels_(InputChannels),
        InputDepth_(InputDepth),
        InputHeight_(InputHeight),
        InputWidth_(InputWidth),
        KernelDepth_(KernelDepth),
        KernelHeight_(KernelHeight),
        KernelWidth_(KernelWidth),
        PaddingLeftDepth_(PaddingLeftDepth),
        PaddingLeftHeight_(PaddingLeftHeight),
        PaddingLeftWidth_(PaddingLeftWidth),
        PaddingRightDepth_(PaddingRightDepth),
        PaddingRightHeight_(PaddingRightHeight),
        PaddingRightWidth_(PaddingRightWidth),
        StrideDepth_(StrideDepth),
        StrideHeight_(StrideHeight),
        StrideWidth_(StrideWidth) {
  }

  // Forwards the stored parameters, in declaration order, to Test() on the
  // fixture's mlas_tester.
  void TestBody() override {
    FixtureBase::mlas_tester->Test(
        BatchCount_,
        InputChannels_,
        InputDepth_,
        InputHeight_,
        InputWidth_,
        KernelDepth_,
        KernelHeight_,
        KernelWidth_,
        PaddingLeftDepth_,
        PaddingLeftHeight_,
        PaddingLeftWidth_,
        PaddingRightDepth_,
        PaddingRightHeight_,
        PaddingRightWidth_,
        StrideDepth_,
        StrideHeight_,
        StrideWidth_);
  }

  // Registers one test with googletest under the suite name reported by
  // MlasPool3DTest<PoolingKind, Threaded>::GetTestSuiteName().
  //
  // The test name encodes every parameter, e.g.
  //   B1/C16/Input_4x4x4/Kernel3x3x3/Pad0,0,0,0,0,0/Stride1,1,1
  // and the same string is also passed as the test's value-param description.
  // Two calls with identical arguments therefore yield identically named
  // tests; callers are expected to pass each parameter set only once.
  //
  // Always returns 1 (the number of tests registered by this call).
  static size_t RegisterSingleTest(size_t BatchCount,
                                   size_t InputChannels,
                                   size_t InputDepth,
                                   size_t InputHeight,
                                   size_t InputWidth,
                                   size_t KernelDepth,
                                   size_t KernelHeight,
                                   size_t KernelWidth,
                                   size_t PaddingLeftDepth,
                                   size_t PaddingLeftHeight,
                                   size_t PaddingLeftWidth,
                                   size_t PaddingRightDepth,
                                   size_t PaddingRightHeight,
                                   size_t PaddingRightWidth,
                                   size_t StrideDepth,
                                   size_t StrideHeight,
                                   size_t StrideWidth) {
    std::stringstream ss;
    ss << "B" << BatchCount << "/"
       << "C" << InputChannels << "/"
       << "Input_" << InputDepth << "x" << InputHeight << "x" << InputWidth << "/"
       << "Kernel" << KernelDepth << "x" << KernelHeight << "x" << KernelWidth << "/"
       << "Pad" << PaddingLeftDepth << "," << PaddingLeftHeight << "," << PaddingLeftWidth
       << "," << PaddingRightDepth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/"
       << "Stride" << StrideDepth << "," << StrideHeight << "," << StrideWidth;
    const std::string test_name = ss.str();

    testing::RegisterTest(
        MlasPool3DTest<PoolingKind, Threaded>::GetTestSuiteName(),
        test_name.c_str(),
        nullptr,            // no type parameter
        test_name.c_str(),  // value-param text mirrors the test name
        __FILE__,
        __LINE__,
        // Important to use the fixture type as the return type here.
        // The parameters are captured by value because the factory is invoked
        // after this function has returned.
        [=]() -> FixtureBase* {
          return new Pooling3dShortExecuteTest<PoolingKind, Threaded>(BatchCount,
                                                                      InputChannels,
                                                                      InputDepth,
                                                                      InputHeight,
                                                                      InputWidth,
                                                                      KernelDepth,
                                                                      KernelHeight,
                                                                      KernelWidth,
                                                                      PaddingLeftDepth,
                                                                      PaddingLeftHeight,
                                                                      PaddingLeftWidth,
                                                                      PaddingRightDepth,
                                                                      PaddingRightHeight,
                                                                      PaddingRightWidth,
                                                                      StrideDepth,
                                                                      StrideHeight,
                                                                      StrideWidth);
        });

    return 1;
  }

  // Registers the short-execute test matrix and returns how many tests were
  // registered.
  //
  // For every cubic input edge length i in {1, 2, 4, 8, 16, 32} (batch 1,
  // 16 channels) six distinct cases are registered. A seventh call that
  // repeated the first case verbatim used to be here; it only produced a
  // second test with the same name, so it has been removed.
  static size_t RegisterShortExecuteTests() {
    size_t test_registered = 0;
    for (unsigned i = 1; i < 64; i <<= 1) {
      // 3x3x3 kernel, no padding, stride 1.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1);
      // 3x3x3 kernel, no padding, stride 2.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 2, 2, 2);
      // 3x3x3 kernel, padding 1 on every side, stride 1.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1);
      // 1x1x1 kernel.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1);
      // Kernel spanning the full input height.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 1, i, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1);
      // Kernel spanning the full input width.
      test_registered += RegisterSingleTest(1, 16, i, i, i, 1, 1, i, 0, 0, 0, 0, 0, 0, 1, 1, 1);
      // NOTE(review): there is no matching case with the kernel spanning the
      // full input depth (i x 1 x 1) - confirm whether that omission is intended.
    }
    return test_registered;
  }

 private:
  // Parameters of the single test case this fixture instance runs.
  const size_t BatchCount_;
  const size_t InputChannels_;
  const size_t InputDepth_;
  const size_t InputHeight_;
  const size_t InputWidth_;
  const size_t KernelDepth_;
  const size_t KernelHeight_;
  const size_t KernelWidth_;
  const size_t PaddingLeftDepth_;
  const size_t PaddingLeftHeight_;
  const size_t PaddingLeftWidth_;
  const size_t PaddingRightDepth_;
  const size_t PaddingRightHeight_;
  const size_t PaddingRightWidth_;
  const size_t StrideDepth_;
  const size_t StrideHeight_;
  const size_t StrideWidth_;
};