This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
convolution-inl.h
602 lines (569 loc) · 24 KB
/
convolution-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file convolution-inl.h
* \brief
* \ref: https://github.com/Yangqing/caffe/wiki/Convolution-in-Caffe:-a-memo
* \author Bing Xu, Jun Wu, Da Zheng
*/
#ifndef MXNET_OPERATOR_NN_CONVOLUTION_INL_H_
#define MXNET_OPERATOR_NN_CONVOLUTION_INL_H_
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <dmlc/logging.h>
#include <dmlc/optional.h>
#include <algorithm>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "../operator_common.h"
#include "../linalg.h"
#include "./im2col.h"
namespace mxnet {
namespace op {
namespace conv {
// Indices of the operator inputs inside in_data / in_grad / req.
// kBias is only present when the layer has a bias term.
enum ConvolutionOpInputs { kData = 0, kWeight = 1, kBias = 2 };
// Index of the single operator output.
enum ConvolutionOpOutputs { kOut = 0 };
// Index of the temporary-space resource inside ctx.requested.
enum ConvolutionOpResource { kTempSpace = 0 };
// Values registered for the `cudnn_tune` parameter.
enum ConvolutionOpCudnnTune { kOff = 0, kLimited = 1, kFastest = 2 };
}  // namespace conv
/*!
 * \brief Hyper-parameters of the Convolution operator.
 *
 * The spatial tuples (kernel, stride, dilate, pad) carry one entry per spatial
 * dimension: (w,), (h, w) or (d, h, w).
 */
struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
  mxnet::TShape kernel;
  mxnet::TShape stride;
  mxnet::TShape dilate;
  mxnet::TShape pad;
  uint32_t num_filter;
  uint32_t num_group;
  uint64_t workspace;
  bool no_bias;
  dmlc::optional<int> cudnn_tune;
  bool cudnn_off;
  dmlc::optional<int> layout;
  DMLC_DECLARE_PARAMETER(ConvolutionParam) {
    DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (w,), (h, w) or (d, h, w)");
    DMLC_DECLARE_FIELD(stride)
        .set_default(mxnet::TShape(0, 0))
        .describe(
            "Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
    DMLC_DECLARE_FIELD(dilate)
        .set_default(mxnet::TShape(0, 0))
        .describe(
            "Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
    DMLC_DECLARE_FIELD(pad)
        .set_default(mxnet::TShape(0, 0))
        .describe("Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.");
    DMLC_DECLARE_FIELD(num_filter)
        .set_lower_bound(1)
        .describe("Convolution filter(channel) number");
    DMLC_DECLARE_FIELD(num_group).set_default(1).describe("Number of group partitions.");
    // Adjacent literals are concatenated; keep a trailing space on each sentence.
    DMLC_DECLARE_FIELD(workspace).set_default(1024).set_lower_bound(0).describe(
        "Maximum temporary workspace allowed (MB) in convolution. "
        "This parameter has two usages. When CUDNN is not used, it determines the "
        "effective batch size of the convolution kernel. When CUDNN is used, it controls "
        "the maximum temporary storage used for tuning the best CUDNN kernel when "
        "`limited_workspace` strategy is used.");
    DMLC_DECLARE_FIELD(no_bias).set_default(false).describe("Whether to disable bias parameter.");
    DMLC_DECLARE_FIELD(cudnn_tune)
        .add_enum("off", conv::kOff)
        .add_enum("limited_workspace", conv::kLimited)
        .add_enum("fastest", conv::kFastest)
        .set_default(dmlc::optional<int>())
        .describe("Whether to pick convolution algo by running performance test.");
    DMLC_DECLARE_FIELD(cudnn_off).set_default(false).describe("Turn off cudnn for this layer.");
    DMLC_DECLARE_FIELD(layout)
        .add_enum("NCW", mshadow::kNCW)
        .add_enum("NCHW", mshadow::kNCHW)
        .add_enum("NCDHW", mshadow::kNCDHW)
        .add_enum("NWC", mshadow::kNWC)
        .add_enum("NHWC", mshadow::kNHWC)
        .add_enum("NDHWC", mshadow::kNDHWC)
        .set_default(dmlc::optional<int>())
        .describe(
            "Set layout for input, output and weight. Empty for\n "
            "default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d. "
            "NHWC and NDHWC are only supported on GPU.");
  }
  // Adjusts kernel size for effects of dilation in the dimension `dim`.
  index_t DilatedKernelSize(int dim) const {
    return 1 + (kernel[dim] - 1) * dilate[dim];
  }
  // Field-wise equality over every declared parameter.
  bool operator==(const ConvolutionParam& other) const {
    return this->kernel == other.kernel && this->stride == other.stride &&
           this->dilate == other.dilate && this->pad == other.pad &&
           this->num_filter == other.num_filter && this->num_group == other.num_group &&
           this->workspace == other.workspace && this->no_bias == other.no_bias &&
           this->cudnn_tune == other.cudnn_tune && this->cudnn_off == other.cudnn_off &&
           this->layout == other.layout;
  }
  /*!
   * \brief Maps a conv::ConvolutionOpCudnnTune value back to its registered name.
   * \param value one of conv::kOff, conv::kLimited, conv::kFastest.
   * \return the string accepted by the `cudnn_tune` field; LOG(FATAL) otherwise.
   */
  std::string CudnnTune2String(int value) const {
    switch (value) {
      case conv::kOff:
        return "off";
      case conv::kLimited:
        return "limited_workspace";
      case conv::kFastest:
        return "fastest";
      default:
        LOG(FATAL) << "Unknown cudnn_tune enum " << value;
    }
    LOG(FATAL) << "should not reach here ";
    return "";
  }
  /*!
   * \brief Maps an mshadow layout flag back to its registered name.
   * \param value any layout registered on the `layout` field above.
   * \return the string accepted by the `layout` field; LOG(FATAL) otherwise.
   */
  std::string Layout2String(int value) const {
    switch (value) {
      case mshadow::kNCW:
        return "NCW";
      case mshadow::kNCHW:
        return "NCHW";
      case mshadow::kNCDHW:
        return "NCDHW";
      case mshadow::kNWC:
        // Registered via add_enum("NWC", ...) above; previously fell into LOG(FATAL).
        return "NWC";
      case mshadow::kNHWC:
        return "NHWC";
      case mshadow::kNDHWC:
        return "NDHWC";
      default:
        LOG(FATAL) << "Unknown layout enum " << value;
    }
    LOG(FATAL) << "should not reach here ";
    return "";
  }
  /*!
   * \brief Writes every parameter into `dict` as a string.
   *
   * Optional enum fields that hold a value are written by their registered name;
   * empty optionals are written using dmlc::optional's stream operator.
   */
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) const {
    std::ostringstream kernel_s, stride_s, dilate_s, pad_s, num_filter_s, num_group_s, workspace_s,
        no_bias_s, cudnn_tune_s, cudnn_off_s, layout_s;
    kernel_s << kernel;
    stride_s << stride;
    dilate_s << dilate;
    pad_s << pad;
    num_filter_s << num_filter;
    num_group_s << num_group;
    workspace_s << workspace;
    no_bias_s << no_bias;
    cudnn_tune_s << cudnn_tune;
    cudnn_off_s << cudnn_off;
    layout_s << layout;
    (*dict)["kernel"] = kernel_s.str();
    (*dict)["stride"] = stride_s.str();
    (*dict)["dilate"] = dilate_s.str();
    (*dict)["pad"] = pad_s.str();
    (*dict)["num_filter"] = num_filter_s.str();
    (*dict)["num_group"] = num_group_s.str();
    (*dict)["workspace"] = workspace_s.str();
    (*dict)["no_bias"] = no_bias_s.str();
    if (cudnn_tune.has_value()) {
      (*dict)["cudnn_tune"] = CudnnTune2String(cudnn_tune.value());
    } else {
      (*dict)["cudnn_tune"] = cudnn_tune_s.str();
    }
    (*dict)["cudnn_off"] = cudnn_off_s.str();
    if (layout.has_value()) {
      (*dict)["layout"] = Layout2String(layout.value());
    } else {
      (*dict)["layout"] = layout_s.str();
    }
  }
};
// Defined elsewhere. Presumably it parses attrs->dict into the ConvolutionParam
// that ConvolutionCompute later reads via nnvm::get<ConvolutionParam>(attrs.parsed).
void ConvolutionParamParser(nnvm::NodeAttrs* attrs);
// Operator signature type parameterized on ConvolutionParam (ParamOpSign is declared elsewhere).
using ConvSignature = ParamOpSign<ConvolutionParam>;
} // namespace op
} // namespace mxnet
namespace std {
/*!
 * \brief Hash for ConvolutionParam.
 *
 * Combines exactly the fields compared by ConvolutionParam::operator==, so equal
 * parameter sets always hash to the same value.
 */
template <>
struct hash<mxnet::op::ConvolutionParam> {
  // Must be const: the Hash requirements invoke the hasher through a const object,
  // as unordered containers do.
  size_t operator()(const mxnet::op::ConvolutionParam& val) const {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.kernel);
    ret = dmlc::HashCombine(ret, val.stride);
    ret = dmlc::HashCombine(ret, val.dilate);
    ret = dmlc::HashCombine(ret, val.pad);
    ret = dmlc::HashCombine(ret, val.num_filter);
    ret = dmlc::HashCombine(ret, val.num_group);
    ret = dmlc::HashCombine(ret, val.workspace);
    ret = dmlc::HashCombine(ret, val.no_bias);
    ret = dmlc::HashCombine(ret, val.cudnn_tune);
    ret = dmlc::HashCombine(ret, val.cudnn_off);
    ret = dmlc::HashCombine(ret, val.layout);
    return ret;
  }
};
}  // namespace std
namespace mxnet {
namespace op {
/*!
 * \brief Caffe-style convolution built on im2col/col2im and GEMM.
 *
 * For each batch sample n and group g:
 *   Forward:  out[n][g] = W[g] * im2col(data[n])[g], then bias is broadcast-added
 *             along the channel axis.
 *   Backward: col = W[g]^T * dOut[n][g], dData[n] = col2im(col);
 *             dW[g] accumulates dOut[n][g] * im2col(data[n])[g]^T over the batch;
 *             dBias = sum of dOut over every axis except the channel axis.
 * When the kernel is 1 with stride 1 and pad 0 in every dimension, the
 * im2col/col2im step is skipped and the GEMMs operate on the blobs directly.
 *
 * Only NCW/NCHW/NCDHW layouts are accepted (checked in Init). Every destination
 * honors its OpReqType, and kNullOp destinations are never written.
 */
template <typename xpu, typename DType>
class ConvolutionOp {
 public:
  /*!
   * \brief Stores the parameters and validates the layout.
   * \param p convolution parameters. Its `workspace` (MB) is converted to a count
   *        of DType elements.
   */
  void Init(ConvolutionParam p) {
    this->param_ = p;
    // convert MBytes first to Bytes and then to elements.
    param_.workspace = (param_.workspace << 20) / sizeof(DType);
    if (param_.layout.has_value()) {
      CHECK(param_.layout.value() == mshadow::kNCW || param_.layout.value() == mshadow::kNCHW ||
            param_.layout.value() == mshadow::kNCDHW)
          << "Only support NCW, NCHW and NCDHW layout";
    }
  }
  /*!
   * \brief Forward pass.
   * \param ctx operator context (stream and requested temp space).
   * \param in_data {data, weight}, plus bias when no_bias is false.
   * \param req write requests; only req[conv::kOut] is used.
   * \param out_data the single output blob.
   */
  void Forward(const OpContext& ctx,
               const std::vector<TBlob>& in_data,
               const std::vector<OpReqType>& req,
               const std::vector<TBlob>& out_data) {
    size_t expected = param_.no_bias ? 2 : 3;
    CHECK_EQ(in_data.size(), expected);
    CHECK_EQ(out_data.size(), 1U);
    _Forward(ctx,
             in_data[conv::kData],
             in_data[conv::kWeight],
             param_.no_bias ? nullptr : &in_data[conv::kBias],
             req[conv::kOut],
             out_data[conv::kOut]);
  }
  /*!
   * \brief Backward pass: gradients w.r.t. data, weight and (optionally) bias.
   * \param out_grad single gradient of the output.
   * \param in_data {data, weight}, plus bias when no_bias is false.
   * \param req one request per entry of in_grad.
   * \param in_grad gradient destinations, laid out like in_data.
   */
  void Backward(const OpContext& ctx,
                const std::vector<TBlob>& out_grad,
                const std::vector<TBlob>& in_data,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& in_grad) {
    CHECK_EQ(out_grad.size(), 1U);
    // in_data / in_grad / req hold {data, weight} plus bias when no_bias is false.
    // The bias value itself is not needed to compute any gradient.
    size_t expected = param_.no_bias ? 2 : 3;
    CHECK_EQ(in_data.size(), expected);
    CHECK_EQ(in_grad.size(), expected);
    CHECK_EQ(req.size(), expected);
    CHECK_EQ(in_data[conv::kWeight].CheckContiguous(), true);
    // The col buffer allocated for dData (if any) is reused for dWeight.
    auto workspace = _BackwardData(
        ctx, out_grad[conv::kOut], in_data[conv::kWeight], req[conv::kData], in_grad[conv::kData]);
    _BackwardWeightsBias(workspace,
                         ctx,
                         out_grad[conv::kOut],
                         in_data[conv::kData],
                         req[conv::kWeight],
                         in_grad[conv::kWeight],
                         param_.no_bias ? OpReqType() : req[conv::kBias],
                         param_.no_bias ? nullptr : &in_grad[conv::kBias]);
  }

 private:
  /*!
   * \brief Computes out_data from in_data, in_weights and optional in_bias per `req`.
   * \return the col-buffer workspace if one was requested. Otherwise the returned
   *         tensor has dptr_ == nullptr (1x1 fast path or req == kNullOp).
   */
  Tensor<xpu, 1, DType> _Forward(const OpContext& ctx,
                                 const TBlob& in_data,
                                 const TBlob& in_weights,
                                 const TBlob* in_bias,
                                 const OpReqType req,
                                 const TBlob& out_data) {
    using namespace mshadow;
    using namespace mshadow::expr;
    LayerSetUp(in_data.shape_, out_data.shape_);
    Stream<xpu>* s = ctx.get_stream<xpu>();
    Tensor<xpu, 1, DType> workspace;
    // _BackwardWeightsBias treats a null dptr_ as "no col buffer allocated yet".
    workspace.dptr_ = nullptr;
    if (req == kNullOp) {
      // Nothing may be written to out_data, not even the bias broadcast below.
      return workspace;
    }
    // initialize weight and col_buffer 3D tensors for using gemm
    index_t M = conv_out_channels_ / group_;
    index_t N = conv_out_spatial_dim_;
    index_t K = kernel_dim_;
    Tensor<xpu, 3, DType> weight_3d =
        in_weights.get_with_shape<xpu, 3, DType>(Shape3(group_, M, K), s);
    Tensor<xpu, 4, DType> output_4d =
        out_data.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, M, N), s);
    if (is_1x1_) {
      // no need to allocating memory and reordering in memory
      Tensor<xpu, 4, DType> input_4d =
          in_data.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, K, N), s);
      for (index_t n = 0; n < num_; ++n) {
        Tensor<xpu, 3, DType> input_3d = input_4d[n];
        Tensor<xpu, 3, DType> output_3d = output_4d[n];
        for (index_t g = 0; g < group_; ++g) {
          linalg_gemm(weight_3d[g], input_3d[g], output_3d[g], false, false, s, req);
        }
      }
    } else {
      // allocate workspace for col_buffer
      workspace = ctx.requested[conv::kTempSpace].get_space_typed<xpu, 1, DType>(
          Shape1(col_buffer_size_), s);
      TBlob col_buffer(workspace.dptr_,
                       ColBufferShape(out_data.shape_),
                       xpu::kDevMask,
                       DataType<DType>::kFlag);
      Tensor<xpu, 3, DType> col_buffer_3d =
          col_buffer.get_with_shape<xpu, 3, DType>(Shape3(group_, K, N), s);
      for (index_t n = 0; n < num_; ++n) {
        // transform image to col_buffer in order to use gemm
        im2col(s,
               in_data.dptr<DType>() + n * input_dim_,
               in_data.shape_,
               col_buffer.shape_,
               param_.kernel,
               param_.pad,
               param_.stride,
               param_.dilate,
               col_buffer.dptr<DType>());
        Tensor<xpu, 3, DType> output_3d = output_4d[n];
        for (index_t g = 0; g < group_; ++g) {
          linalg_gemm(weight_3d[g], col_buffer_3d[g], output_3d[g], false, false, s, req);
        }
      }
    }
    if (bias_term_) {
      CHECK(in_bias != nullptr);
      Tensor<xpu, 1, DType> bias = in_bias->get<xpu, 1, DType>(s);
      Tensor<xpu, 3, DType> output_3d = out_data.get_with_shape<xpu, 3, DType>(
          Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s);
      // has bias term, broadcast it to the same shape of output_3d in channel dim
      output_3d += mshadow::expr::broadcast<1>(bias, output_3d.shape_);
    }
    return workspace;
  }
  /*!
   * \brief Computes dLoss/dData into data_grad_dst according to data_grad_req.
   * \return the col-buffer workspace if one was requested. Otherwise the returned
   *         tensor has dptr_ == nullptr.
   */
  Tensor<xpu, 1, DType> _BackwardData(const OpContext& ctx,
                                      const TBlob& out_grad,
                                      const TBlob& weights,
                                      const OpReqType data_grad_req,
                                      const TBlob& data_grad_dst) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(weights.CheckContiguous(), true);
    LayerSetUp(data_grad_dst.shape_, out_grad.shape_);
    Stream<xpu>* s = ctx.get_stream<xpu>();
    Tensor<xpu, 1, DType> workspace;
    // _BackwardWeightsBias treats a null dptr_ as "no col buffer allocated yet".
    workspace.dptr_ = nullptr;
    if (data_grad_req == kNullOp) {
      return workspace;
    }
    // initialize weight and col_buffer 3D tensors for using gemm
    index_t M = kernel_dim_;
    index_t N = conv_out_spatial_dim_;
    index_t K = conv_out_channels_ / group_;
    Tensor<xpu, 3, DType> weight_3d =
        weights.get_with_shape<xpu, 3, DType>(Shape3(group_, K, M), s);
    Tensor<xpu, 4, DType> out_grad_4d =
        out_grad.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, K, N), s);
    if (is_1x1_) {
      // no need to allocating memory and reordering in memory
      Tensor<xpu, 4, DType> in_grad_4d =
          data_grad_dst.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, M, N), s);
      for (index_t n = 0; n < num_; ++n) {
        Tensor<xpu, 3, DType> in_grad_3d = in_grad_4d[n];
        Tensor<xpu, 3, DType> out_grad_3d = out_grad_4d[n];
        for (index_t g = 0; g < group_; ++g) {
          // Forward the request so kAddTo accumulates instead of overwriting,
          // matching the col2im path below.
          linalg_gemm(weight_3d[g], out_grad_3d[g], in_grad_3d[g], true, false, s, data_grad_req);
        }
      }
    } else {
      // allocate workspace for col_buffer
      workspace = ctx.requested[conv::kTempSpace].get_space_typed<xpu, 1, DType>(
          Shape1(col_buffer_size_), s);
      TBlob col_buffer(workspace.dptr_,
                       ColBufferShape(out_grad.shape_),
                       xpu::kDevMask,
                       DataType<DType>::kFlag);
      Tensor<xpu, 3, DType> col_buffer_3d =
          col_buffer.get_with_shape<xpu, 3, DType>(Shape3(group_, M, N), s);
      for (index_t n = 0; n < num_; ++n) {
        Tensor<xpu, 3, DType> out_grad_3d = out_grad_4d[n];
        // col_buffer is scratch space, so it is always overwritten.
        for (index_t g = 0; g < group_; ++g) {
          linalg_gemm(weight_3d[g], out_grad_3d[g], col_buffer_3d[g], true, false, s);
        }
        col2im(s,
               col_buffer.dptr<DType>(),
               data_grad_dst.shape_,
               col_buffer.shape_,
               param_.kernel,
               param_.pad,
               param_.stride,
               param_.dilate,
               data_grad_dst.dptr<DType>() + n * input_dim_,
               data_grad_req);
      }
    }
    return workspace;
  }
  /*!
   * \brief Computes dLoss/dWeights and dLoss/dBias.
   * \param workspace col buffer to reuse. When its dptr_ is nullptr, a fresh one is requested.
   * \param bias_grad_dst bias gradient destination; must be non-null when the layer has a bias.
   */
  void _BackwardWeightsBias(Tensor<xpu, 1, DType> workspace,
                            const OpContext& ctx,
                            const TBlob& out_grad,
                            const TBlob& data,
                            const OpReqType weights_grad_req,
                            const TBlob& weights_grad_dst,
                            const OpReqType bias_grad_req,
                            const TBlob* const bias_grad_dst) {
    using namespace mshadow;
    using namespace mshadow::expr;
    LayerSetUp(data.shape_, out_grad.shape_);
    Stream<xpu>* s = ctx.get_stream<xpu>();
    // dWeight is accumulated over the batch with kAddTo after the first sample, so
    // with kNullOp it must be skipped entirely, or samples 1..N-1 would be written.
    if (weights_grad_req != kNullOp) {
      // initialize weight and col_buffer 3D tensors for using gemm
      index_t M = kernel_dim_;
      index_t N = conv_out_spatial_dim_;
      index_t K = conv_out_channels_ / group_;
      Tensor<xpu, 4, DType> out_grad_4d =
          out_grad.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, K, N), s);
      Tensor<xpu, 3, DType> dweight_3d =
          weights_grad_dst.get_with_shape<xpu, 3, DType>(Shape3(group_, K, M), s);
      if (is_1x1_) {
        // no need to allocating memory and reordering in memory
        Tensor<xpu, 4, DType> input_4d =
            data.get_with_shape<xpu, 4, DType>(Shape4(num_, group_, M, N), s);
        for (index_t n = 0; n < num_; ++n) {
          Tensor<xpu, 3, DType> input_3d = input_4d[n];
          Tensor<xpu, 3, DType> out_grad_3d = out_grad_4d[n];
          for (index_t g = 0; g < group_; ++g) {
            auto request = (n == 0) ? weights_grad_req : kAddTo;
            linalg_gemm(out_grad_3d[g], input_3d[g], dweight_3d[g], false, true, s, request);
          }
        }
      } else {
        // allocate workspace for col_buffer unless the caller already did
        if (workspace.dptr_ == nullptr) {
          workspace = ctx.requested[conv::kTempSpace].get_space_typed<xpu, 1, DType>(
              Shape1(col_buffer_size_), s);
        }
        TBlob col_buffer(workspace.dptr_,
                         ColBufferShape(out_grad.shape_),
                         xpu::kDevMask,
                         DataType<DType>::kFlag);
        Tensor<xpu, 3, DType> col_buffer_3d =
            col_buffer.get_with_shape<xpu, 3, DType>(Shape3(group_, M, N), s);
        for (index_t n = 0; n < num_; ++n) {
          Tensor<xpu, 3, DType> out_grad_3d = out_grad_4d[n];
          // dWeight should accumulate across the batch and group
          im2col(s,
                 data.dptr<DType>() + n * input_dim_,
                 data.shape_,
                 col_buffer.shape_,
                 param_.kernel,
                 param_.pad,
                 param_.stride,
                 param_.dilate,
                 col_buffer.dptr<DType>());
          for (index_t g = 0; g < group_; ++g) {
            auto request = (n == 0) ? weights_grad_req : kAddTo;
            linalg_gemm(out_grad_3d[g], col_buffer_3d[g], dweight_3d[g], false, true, s, request);
          }
        }
      }
    }
    // bias gradient
    if (bias_term_) {
      CHECK(bias_grad_dst != nullptr);
      Tensor<xpu, 1, DType> dbias = bias_grad_dst->get<xpu, 1, DType>(s);
      Tensor<xpu, 3, DType> dout = out_grad.get_with_shape<xpu, 3, DType>(
          Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s);
      ASSIGN_DISPATCH(dbias, bias_grad_req, sumall_except_dim<1>(dout));
    }
  }
  /*!
   * \brief Shape of the per-sample im2col buffer:
   *        (conv_in_channels_ * prod(kernel), output spatial dims taken from oshape[2:]).
   *        Requires LayerSetUp to have been called.
   */
  mxnet::TShape ColBufferShape(const mxnet::TShape& oshape) const {
    mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1);
    col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size();
    for (int i = 1; i < col_buffer_shape.ndim(); ++i) {
      col_buffer_shape[i] = oshape[i + 1];
    }
    return col_buffer_shape;
  }
  /*!
   * \brief Derives the GEMM/im2col dimensions from the input and output shapes.
   * \param ishape input shape (batch, channels, spatial...).
   * \param oshape output shape (batch, num_filter, spatial...).
   */
  void LayerSetUp(const mxnet::TShape& ishape, const mxnet::TShape& oshape) {
    channel_axis_ = 1;  // hard code channel axis
    const index_t first_spatial_axis = channel_axis_ + 1;
    const int num_axes = param_.kernel.ndim() + 2;
    num_spatial_axes_ = num_axes - first_spatial_axis;
    is_1x1_ = true;
    for (int i = 0; i < param_.kernel.ndim(); ++i) {
      is_1x1_ &= param_.kernel[i] == 1 && param_.stride[i] == 1 && param_.pad[i] == 0;
      if (!is_1x1_)
        break;
    }
    // batch size
    num_ = ishape[0];
    // number of input channels
    channels_ = ishape[1];
    group_ = param_.num_group;
    conv_out_channels_ = param_.num_filter;
    conv_in_channels_ = channels_;
    bias_term_ = !param_.no_bias;
    kernel_dim_ = conv_in_channels_ / group_ * param_.kernel.Size();
    weight_offset_ = conv_out_channels_ * kernel_dim_ / group_;
    conv_out_spatial_dim_ = oshape.ProdShape(2, oshape.ndim());
    col_offset_ = kernel_dim_ * conv_out_spatial_dim_;
    output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_;
    // size of the column buffer used for storing im2col-ed pixels
    col_buffer_size_ = kernel_dim_ * group_ * conv_out_spatial_dim_;
    // input/output image size (#channels * height * width)
    input_dim_ = ishape.ProdShape(1, ishape.ndim());
    output_dim_ = oshape.ProdShape(1, oshape.ndim());
    num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_;
    num_kernels_col2im_ = input_dim_;
  }

 private:
  ConvolutionParam param_;
  index_t channel_axis_;          // channel axis of the input
  index_t channels_;              // number of channels of input image
  index_t num_spatial_axes_;      // number of spatial axes
  index_t num_;                   // batch size
  index_t group_;                 // number of groups
  index_t conv_out_channels_;     // number of output channels (num_filter)
  index_t conv_out_spatial_dim_;  // number of pixels of output images per channel
  index_t conv_in_channels_;      // number of input channels
  index_t kernel_dim_;            // number of input channels per group * kernel size
  index_t weight_offset_;         // number of output channels per group * kernel_dim_
  index_t col_offset_;
  index_t output_offset_;
  index_t col_buffer_size_;
  index_t input_dim_;
  index_t output_dim_;
  index_t num_kernels_im2col_;
  index_t num_kernels_col2im_;
  bool bias_term_;  // has bias term?
  bool is_1x1_;
  template <typename xpu_, typename DType_>
  friend class DeconvolutionOp;
};  // class ConvolutionOp
/*!
 * \brief FCompute entry point for the Convolution forward pass.
 *
 * Instantiates ConvolutionOp for the real dtype of the data input and runs Forward.
 * \param attrs node attributes; attrs.parsed must hold a ConvolutionParam.
 * \param inputs {data, weight} or {data, weight, bias}.
 * \param req write requests for the outputs.
 * \param outputs the single output blob.
 */
template <typename xpu>
void ConvolutionCompute(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  const ConvolutionParam& conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
  const int data_dtype = inputs[conv::kData].type_flag_;
  MSHADOW_REAL_TYPE_SWITCH(data_dtype, DType, {
    ConvolutionOp<xpu, DType> conv_op;
    conv_op.Init(conv_param);
    conv_op.Forward(ctx, inputs, req, outputs);
  });
}
/*!
 * \brief FCompute entry point for the Convolution backward pass.
 *
 * Input layout: inputs[0] is the output gradient and inputs[1..] are the forward
 * inputs (data, weight[, bias]). `outputs` receives the matching input gradients.
 * \param attrs node attributes; attrs.parsed must hold a ConvolutionParam.
 */
template <typename xpu>
void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  const ConvolutionParam& conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
  const TBlob& head_grad = inputs[0];
  const std::vector<TBlob> fwd_inputs(inputs.begin() + 1, inputs.end());
  const std::vector<TBlob> head_grads(1, head_grad);
  MSHADOW_REAL_TYPE_SWITCH(head_grad.type_flag_, DType, {
    ConvolutionOp<xpu, DType> conv_op;
    conv_op.Init(conv_param);
    conv_op.Backward(ctx, head_grads, fwd_inputs, req, outputs);
  });
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NN_CONVOLUTION_INL_H_