Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] Enable signed int8 support for convolution. #13697

Merged
merged 45 commits into from
Feb 10, 2019
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
9ba9d4c
Enable s8s8 support for MKLDNN convolution.
ZhennanQin Dec 19, 2018
c642d33
Fix cpp build
ZhennanQin Dec 20, 2018
57c2423
Fix build.
ZhennanQin Dec 20, 2018
f3a2dd2
Merge remote-tracking branch 'offical' into s8_conv
ZhennanQin Dec 20, 2018
923b9ee
Fix build
ZhennanQin Dec 20, 2018
5ea756c
Remove openmp min/max reduction for windows build
ZhennanQin Dec 20, 2018
d1ba1ad
Add mkldnn_OIhw4i16o4i_s8s8 support
ZhennanQin Dec 23, 2018
7531e98
Add all s8s8 weight format
ZhennanQin Dec 23, 2018
a9fca92
Merge branch 'master' into s8_conv
ZhennanQin Dec 23, 2018
3f75e82
Change ssd quantize script.
ZhennanQin Dec 26, 2018
0acd9ce
Update
ZhennanQin Jan 2, 2019
999e54f
Merge remote-tracking branch 'offical/master' into s8_conv
ZhennanQin Jan 3, 2019
a713664
Manually cast mshadow shape size to size_t
ZhennanQin Jan 3, 2019
9129b41
Fix merge.
ZhennanQin Jan 3, 2019
159dc11
Fix perl package.
ZhennanQin Jan 3, 2019
3cd9d99
Retrigger CI
ZhennanQin Jan 3, 2019
0b98e94
Fix GPU test
ZhennanQin Jan 4, 2019
989477b
Fix GPU test
ZhennanQin Jan 4, 2019
894f13f
Rerun CI
ZhennanQin Jan 4, 2019
60c16a7
Rerun CI
ZhennanQin Jan 4, 2019
95ab7a1
Rerun CI
ZhennanQin Jan 5, 2019
c861cfc
Rerun CI
ZhennanQin Jan 5, 2019
4dacf2f
Remove weight_channelwise_scale from params.
ZhennanQin Jan 7, 2019
0a612e6
Fix
ZhennanQin Jan 10, 2019
1f7fc56
Keep API compatible.
ZhennanQin Jan 14, 2019
cffaa05
Rerun CI
ZhennanQin Jan 15, 2019
12d9f1f
Rerun CI
ZhennanQin Jan 15, 2019
807a48c
Rerun CI
ZhennanQin Jan 15, 2019
d8442e3
Merge remote-tracking branch 'offical/master' into s8_conv
ZhennanQin Jan 15, 2019
905b144
Merge remote-tracking branch 'offical/master' into s8_conv
ZhennanQin Jan 28, 2019
62e3fc0
Rerun CI
ZhennanQin Jan 28, 2019
69a6e28
Address comments.
ZhennanQin Jan 29, 2019
bf655e2
fix.
ZhennanQin Feb 1, 2019
e3a8d0a
Address debug build.
ZhennanQin Feb 2, 2019
d58311b
Add comment for next_impl
ZhennanQin Feb 2, 2019
93295ab
Rerun ci
ZhennanQin Feb 2, 2019
eae2557
Add new api MXExecutorSetMonitorCallbackEX
ZhennanQin Feb 4, 2019
11217c2
Add default value for monitor_all for cpp header.
ZhennanQin Feb 4, 2019
bfc91a6
Rerun CI
ZhennanQin Feb 5, 2019
8932fd1
Merge remote-tracking branch 'offical' into s8_conv
ZhennanQin Feb 5, 2019
fe08128
fix
ZhennanQin Feb 5, 2019
63dfdbf
script change for uint8.
ZhennanQin Feb 5, 2019
0b5e563
Merge remote-tracking branch 'upstream/master' into s8_conv
xinyu-intel Feb 8, 2019
1210b5c
trigger ci
xinyu-intel Feb 8, 2019
bff42ff
trigger ci
xinyu-intel Feb 8, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ __pycache__
build
cmake-build*
data
model
recommonmark
deps

Expand Down
3 changes: 2 additions & 1 deletion cpp-package/include/mxnet-cpp/monitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ class Monitor {
/*!
* \brief install callback to executor. Supports installing to multiple executors.
* \param exe The executor to install to.
* \param monitor_all If true, monitor both input and output, otherwise monitor output only.
*/
void install(Executor *exe);
void install(Executor *exe, bool monitor_all = false);
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Start collecting stats for current batch. Call before calling forward.
Expand Down
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/monitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func)
: interval(interval), pattern(pattern), stat_func(stat_func), step(0) {
}

inline void Monitor::install(Executor *exe) {
inline void Monitor::install(Executor *exe, bool monitor_all) {
MXExecutorSetMonitorCallback(exe->handle_,
static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
this);
static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
this, monitor_all);
exes.push_back(exe);
}

Expand Down
42 changes: 15 additions & 27 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,24 @@ def convert_from_gluon(model_name, image_shape, classes=1000, logger=None):
symnet = mx.symbol.load_json(y.tojson())
params = net.collect_params()
args = {}
auxs = {}
auxs = {}
for param in params.values():
v = param._reduce()
k = param.name
if 'running' in k:
auxs[k] = v
else:
args[k] = v
args[k] = v
mod = mx.mod.Module(symbol=symnet, context=mx.cpu(),
label_names = ['softmax_label'])
mod.bind(for_training=False,
data_shapes=[('data', (1,) +
mod.bind(for_training=False,
data_shapes=[('data', (1,) +
tuple([int(i) for i in image_shape.split(',')]))])
mod.set_params(arg_params=args, aux_params=auxs)
dst_dir = os.path.join(dir_path, 'model')
prefix = os.path.join(dir_path, 'model', model_name)
if not os.path.isdir(dst_dir):
os.mkdir(dst_dir)
os.mkdir(dst_dir)
mod.save_checkpoint(prefix, 0)
return prefix

Expand Down Expand Up @@ -104,7 +104,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
'you can set to custom to load your pre-trained model.')
parser.add_argument('--use-gluon-model', type=bool, default=False,
help='If enabled, will download pretrained model from Gluon-CV '
'and convert to symbolic model ')
'and convert to symbolic model ')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--label-name', type=str, default='softmax_label')
parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec',
Expand All @@ -114,7 +114,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
help='number of threads for data decoding')
parser.add_argument('--num-calib-batches', type=int, default=10,
help='number of batches for calibration')
parser.add_argument('--exclude-first-conv', action='store_true', default=True,
parser.add_argument('--exclude-first-conv', action='store_true', default=False,
help='excluding quantizing the first conv layer since the'
' input data may have negative value which doesn\'t support at moment' )
parser.add_argument('--shuffle-dataset', action='store_true', default=True,
Expand All @@ -140,8 +140,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
' thresholds. This mode is expected to produce the best inference accuracy of all three'
' kinds of quantized models if the calibration dataset is representative enough of the'
' inference dataset.')
parser.add_argument('--quantized-dtype', type=str, default='uint8',
choices=['int8', 'uint8'],
parser.add_argument('--quantized-dtype', type=str, default='auto',
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
choices=['auto', 'int8', 'uint8'],
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
help='quantization destination data type for input data')
parser.add_argument('--enable-calib-quantize', type=bool, default=True,
help='If enabled, the quantize op will '
Expand Down Expand Up @@ -198,40 +198,36 @@ def save_params(fname, arg_params, aux_params, logger=None):
# get image shape
image_shape = args.image_shape

calib_layer = lambda name: name.endswith('_output') or name == "data"
exclude_first_conv = args.exclude_first_conv
excluded_sym_names = []
if args.model == 'imagenet1k-resnet-152':
rgb_mean = '0,0,0'
rgb_std = '1,1,1'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['flatten0', 'fc1', 'pooling0']
excluded_sym_names += ['flatten0', 'fc1']
if exclude_first_conv:
excluded_sym_names += ['conv0']
elif args.model == 'imagenet1k-inception-bn':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '1,1,1'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['flatten', 'fc1']
if exclude_first_conv:
excluded_sym_names += ['conv_1']
elif args.model in ['resnet50_v1', 'resnet101_v1']:
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['resnetv10_dense0_fwd', 'resnetv10_pool0_fwd']
excluded_sym_names += ['resnetv10_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model == 'squeezenet1.0':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['squeezenet0_flatten0_flatten0']
if exclude_first_conv:
excluded_sym_names += ['squeezenet0_conv0_fwd']
elif args.model == 'mobilenet1.0':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['mobilenet0_flatten0_flatten0',
'mobilenet0_dense0_fwd',
'mobilenet0_pool0_fwd']
Expand All @@ -240,22 +236,15 @@ def save_params(fname, arg_params, aux_params, logger=None):
elif args.model == 'inceptionv3':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['inception30_dense0_fwd',
'inception30_pool0_fwd']
excluded_sym_names += ['inception30_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['inception30_conv0_fwd']
elif args.model == 'custom':
# add rgb mean/std of your model.
rgb_mean = '0,0,0'
rgb_std = '0,0,0'
calib_layer = lambda name: name.endswith('_output')
# add layer names you donnot want to quantize.
# add conv/pool layer names that has negative inputs
# since Intel MKL-DNN only support uint8 quantization temporary.
# add all fc layer names since Intel MKL-DNN does not support temporary.
excluded_sym_names += ['layers']
# add your first conv layer names since Intel MKL-DNN only support uint8 quantization temporary.
if exclude_first_conv:
excluded_sym_names += ['layers']
else:
Expand All @@ -272,7 +261,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
logger.info('rgb_std = %s' % rgb_std)
rgb_std = [float(i) for i in rgb_std.split(',')]
std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}
std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}
combine_mean_std = {}
combine_mean_std.update(mean_args)
combine_mean_std.update(std_args)
Expand Down Expand Up @@ -303,8 +292,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
calib_mode=calib_mode, calib_data=data,
num_calib_examples=num_calib_batches * batch_size,
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
label_names=(label_name,), calib_quantize_op = True,
logger=logger)
label_names=(label_name,), logger=logger)
if calib_mode == 'entropy':
suffix = '-quantized-%dbatches-entropy' % num_calib_batches
elif calib_mode == 'naive':
Expand Down
25 changes: 12 additions & 13 deletions example/ssd/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--num-calib-batches', type=int, default=5,
help='number of batches for calibration')
parser.add_argument('--exclude-first-conv', action='store_true', default=True,
parser.add_argument('--exclude-first-conv', action='store_true', default=False,
help='excluding quantizing the first conv layer since the'
' number of channels is usually not a multiple of 4 in that layer'
' which does not satisfy the requirement of cuDNN')
Expand All @@ -78,8 +78,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
' thresholds. This mode is expected to produce the best inference accuracy of all three'
' kinds of quantized models if the calibration dataset is representative enough of the'
' inference dataset.')
parser.add_argument('--quantized-dtype', type=str, default='uint8',
choices=['int8', 'uint8'],
parser.add_argument('--quantized-dtype', type=str, default='auto',
choices=['auto', 'int8', 'uint8'],
help='quantization destination data type for input data')

args = parser.parse_args()
Expand Down Expand Up @@ -115,18 +115,19 @@ def save_params(fname, arg_params, aux_params, logger=None):
# get image shape
image_shape = '3,300,300'

def calib_layer(name): return not (name.endswith('_data') or
name.endswith('_weight') or
name.endswith('_bias') or
name.endswith('_workspace'))
# Quantization layer configs
exclude_first_conv = args.exclude_first_conv
excluded_sym_names = []
rgb_mean = '123,117,104'
for i in range(1,19):
excluded_sym_names += ['flatten'+str(i)]
excluded_sym_names += ['relu4_3_cls_pred_conv',
'relu7_cls_pred_conv',
'relu4_3_loc_pred_conv',
'multibox_loc_pred',
'concat0',
'concat1']
excluded_sym_names += ['multibox_loc_pred',
'concat0',
'concat1']
if exclude_first_conv:
excluded_sym_names += ['conv1_1']

Expand Down Expand Up @@ -158,10 +159,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
ctx=ctx, excluded_sym_names=excluded_sym_names,
calib_mode=calib_mode, calib_data=eval_iter,
num_calib_examples=num_calib_batches * batch_size,
calib_layer=None, quantized_dtype=args.quantized_dtype,
label_names=(label_name,),
calib_quantize_op=True,
logger=logger)
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
label_names=(label_name,), logger=logger)
sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
Expand Down
6 changes: 4 additions & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
* \param num_offline number of parameters that are quantized offline
* \param offline_params array of c strings representing the names of params quantized offline
* \param quantized_dtype the quantized destination type for input data.
* \param calib_quantize whether calibrate quantize op with offline calibration data.
* \param calib_quantize **Deprecated**. quantize op will always be calibrated if could.
*/
MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
const mx_uint num_excluded_symbols,
Expand Down Expand Up @@ -1843,10 +1843,12 @@ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,

/*!
* \brief set a call back to notify the completion of operation
* \param monitor_all If true, monitor both input and output, otherwise monitor output only.
*/
MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExecutorMonitorCallback callback,
void* callback_handle);
void* callback_handle,
bool monitor_all);
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class Executor {
/*!
* \brief Install a callback to notify the completion of operation.
*/
virtual void SetMonitorCallback(const MonitorCallback& callback) {}
virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all = false) {}
}; // class executor
} // namespace mxnet
#endif // MXNET_EXECUTOR_H_
2 changes: 1 addition & 1 deletion include/mxnet/tensor_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class TBlob {
CHECK(Device::kDevMask == this->dev_mask())
<< "TBlob.get: device type do not match specified type";
CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous";
CHECK_EQ(this->shape_.Size(), shape.Size())
CHECK_EQ(this->shape_.Size(), static_cast<size_t>(shape.Size()))
<< "TBlob.get_with_shape: new and old shape do not match total elements";
return mshadow::Tensor<Device, dim, DType>(dptr<DType>(), shape,
shape[dim - 1], stream);
Expand Down
7 changes: 5 additions & 2 deletions perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,17 @@ method backward(
----------
$callback : CodeRef
Takes a string and an NDArrayHandle.
$monitor_all : Bool, default 0
If true, monitor both input and output, otherwise monitor output only.
=cut

method set_monitor_callback(CodeRef $callback)
method set_monitor_callback(CodeRef $callback, Bool $monitor_all=0)
{
check_call(
AI::MXNetCAPI::ExecutorSetMonitorCallback(
$self->handle,
$callback
$callback,
$monitor_all
)
);
}
Expand Down
5 changes: 3 additions & 2 deletions perl-package/AI-MXNetCAPI/mxnet.i
Original file line number Diff line number Diff line change
Expand Up @@ -1614,10 +1614,12 @@ int MXExecutorReshape(int partial_shaping,

/*!
* \brief set a call back to notify the completion of operation
* \param monitor_all If true, monitor both input and output, otherwise monitor output only.
*/
int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExecutorMonitorCallback callback,
void* callback_handle);
void* callback_handle,
bool monitor_all);
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
Expand Down Expand Up @@ -2167,4 +2169,3 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** cuda_kernel_
mx_uint grid_dim_z, mx_uint block_dim_x,
mx_uint block_dim_y, mx_uint block_dim_z,
mx_uint shared_mem);

Loading