Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MKLDNN] Fix _copyto (#17173)
Browse files Browse the repository at this point in the history
* fix_copyto

* only exclude _copyto

* trigger CI
  • Loading branch information
wuxun-zhang authored and pengzhao-intel committed Jan 2, 2020
1 parent 5d9cbdb commit e65fc4b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ inline void PushFComputeEx(const FComputeEx& fn,
}
// add for mkldnn OP + no mkldnn OP
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs.op, false)) {
if (!is_mkldnn.get(attrs.op, false) && exec_type != ExecType::kCrossDeviceCopy) {
std::vector<NDArray> inputs_fallback;
CreateDefaultInputs(inputs, &inputs_fallback);
fn(attrs, opctx, inputs_fallback, req, outputs);
Expand Down Expand Up @@ -543,7 +543,7 @@ inline void PushOperator(const OpStatePtr& state,
}
// add for mkldnn OP + no mkldnn OP
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs.op, false)) {
if (!is_mkldnn.get(attrs.op, false) && exec_type != ExecType::kCrossDeviceCopy) {
std::vector<NDArray> inputs_fallback;
CreateDefaultInputs(inputs, &inputs_fallback);
fcompute_ex(state, opctx, inputs_fallback, req, outputs);
Expand Down
20 changes: 11 additions & 9 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,15 +1137,17 @@ def check_quantize_net(qdtype):
quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True)
quantized_resnet18_v1(random_data)

quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype,
exclude_layers=None,
exclude_layers_match=excluded_names_match,
calib_data=calib_data,
calib_mode='naive',
num_calib_examples=num_calib_examples,
ctx=mx.current_context())
quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True)
quantized_resnet18_v1(random_data)
for mode in ['naive', 'entropy']:
qdtype = qdtype if mode is 'naive' else 'auto'
quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype,
exclude_layers=None,
exclude_layers_match=excluded_names_match,
calib_data=calib_data,
calib_mode=mode,
num_calib_examples=num_calib_examples,
ctx=mx.current_context())
quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True)
quantized_resnet18_v1(random_data)

for qdtype in ['int8', 'uint8']:
check_quantize_net(qdtype)
Expand Down

0 comments on commit e65fc4b

Please sign in to comment.