diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index edc2f2adf4c2..c6d9b8b7b969 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -479,7 +479,7 @@ inline void PushFComputeEx(const FComputeEx& fn, } // add for mkldnn OP + no mkldnn OP const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); - if (!is_mkldnn.get(attrs.op, false)) { + if (!is_mkldnn.get(attrs.op, false) && exec_type != ExecType::kCrossDeviceCopy) { std::vector inputs_fallback; CreateDefaultInputs(inputs, &inputs_fallback); fn(attrs, opctx, inputs_fallback, req, outputs); @@ -543,7 +543,7 @@ inline void PushOperator(const OpStatePtr& state, } // add for mkldnn OP + no mkldnn OP const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); - if (!is_mkldnn.get(attrs.op, false)) { + if (!is_mkldnn.get(attrs.op, false) && exec_type != ExecType::kCrossDeviceCopy) { std::vector inputs_fallback; CreateDefaultInputs(inputs, &inputs_fallback); fcompute_ex(state, opctx, inputs_fallback, req, outputs); diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 8f433524b325..ec6ddfdf67f4 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -1137,15 +1137,17 @@ def check_quantize_net(qdtype): quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) quantized_resnet18_v1(random_data) - quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, - exclude_layers=None, - exclude_layers_match=excluded_names_match, - calib_data=calib_data, - calib_mode='naive', - num_calib_examples=num_calib_examples, - ctx=mx.current_context()) - quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) - quantized_resnet18_v1(random_data) + for mode in ['naive', 'entropy']: + qdtype = qdtype if mode is 'naive' else 'auto' + quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, + exclude_layers=None, + exclude_layers_match=excluded_names_match, + calib_data=calib_data, + calib_mode=mode, + num_calib_examples=num_calib_examples, + ctx=mx.current_context()) + quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) + quantized_resnet18_v1(random_data) for qdtype in ['int8', 'uint8']: check_quantize_net(qdtype)