Add Q16x8 Depthwise Conv Support (#2140)
Adds support for 16-bit activations + 8-bit weights (Q16x8) for depthwise convolution in the reference kernel. Uses a 64-bit bias to match TFLite. Also adds a passthrough to the Q16x8 reference kernel for Xtensa, CEVA, and ARC (CMSIS already has its own implementation).

Tested:
depthwise_conv_test

BUG=2141
mbrooksx authored Aug 1, 2023
1 parent ca74563 commit a7846a1
Showing 6 changed files with 198 additions and 8 deletions.
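
For orientation, here is a minimal sketch (not the code in this commit) of the per-output-element arithmetic the Q16x8 reference kernel performs: int16 activations (zero point 0 in TFLite's 16x8 scheme) are multiplied by int8 weights into a wide accumulator, the int64 bias is added, and the result is rescaled and clamped to int16. The helper name and the double-valued scale are illustrative only; the real kernel rescales with a per-channel fixed-point multiplier and shift.

#include <algorithm>
#include <cstdint>

// Illustrative only: one output element of a q16x8 depthwise conv.
// The 64-bit bias seeds a 64-bit accumulator, matching TFLite.
int16_t AccumulateQ16x8(const int16_t* input, const int8_t* filter, int len,
                        int64_t bias, double effective_scale) {
  int64_t acc = bias;
  for (int i = 0; i < len; ++i) {
    acc += static_cast<int64_t>(input[i]) * static_cast<int64_t>(filter[i]);
  }
  // Stand-in for the reference kernel's per-channel quantized
  // multiplier + shift (MultiplyByQuantizedMultiplier).
  int64_t scaled = static_cast<int64_t>(acc * effective_scale);
  scaled = std::min<int64_t>(std::max<int64_t>(scaled, INT16_MIN), INT16_MAX);
  return static_cast<int16_t>(scaled);
}
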
24 changes: 24 additions & 0 deletions tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
@@ -660,6 +660,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bias, output);
}
break;
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
reference_integer_ops::DepthwiseConvPerChannel(
DepthwiseConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Filter type %s (%d) for input type %s not supported.",
TfLiteTypeGetName(filter->type), filter->type,
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
break;
}
default:
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
24 changes: 24 additions & 0 deletions tensorflow/lite/micro/kernels/ceva/depthwise_conv.cc
@@ -224,6 +224,30 @@ TfLiteStatus EvalCEVA(TfLiteContext* context, TfLiteNode* node) {
EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
output);
break;
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
reference_integer_ops::DepthwiseConvPerChannel(
DepthwiseConvParamsQuantized(*params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Filter type %s (%d) for input type %s not supported.",
TfLiteTypeGetName(filter->type), filter->type,
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
break;
}
default:
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
29 changes: 27 additions & 2 deletions tensorflow/lite/micro/kernels/depthwise_conv.cc
Expand Up @@ -101,8 +101,33 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
default:
MicroPrintf("Filter type %s (%d) not supported.",
TfLiteTypeGetName(filter->type), filter->type);
MicroPrintf("Filter type %s (%d) for input type %s not supported.",
TfLiteTypeGetName(filter->type), filter->type,
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
break;
}
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
reference_integer_ops::DepthwiseConvPerChannel(
DepthwiseConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Filter type %s (%d) for input type %s not supported.",
TfLiteTypeGetName(filter->type), filter->type,
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
break;
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/depthwise_conv_common.cc
Expand Up @@ -192,7 +192,8 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
context,
input->type == filter->type ||
(input->type == kTfLiteInt8 &&
- (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
+ (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)) ||
+ (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8),
"Hybrid models are not supported on TFLite Micro.");

if (filter->type == kTfLiteInt4) {
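
The check above now admits three non-hybrid combinations. A hypothetical summary helper (not part of the commit; bias widths follow TFLite convention, int32 bias for int8 inputs and int64 bias for int16 inputs):

#include "tensorflow/lite/c/common.h"

// Hypothetical: summarizes the (input, filter) type pairs that
// DepthwiseConvPrepare accepts after this change.
bool IsSupportedDepthwiseCombo(TfLiteType input_type, TfLiteType filter_type) {
  return input_type == filter_type ||  // e.g. float32 input with float32 filter
         (input_type == kTfLiteInt8 &&
          (filter_type == kTfLiteInt4 || filter_type == kTfLiteInt8)) ||
         (input_type == kTfLiteInt16 && filter_type == kTfLiteInt8);  // new q16x8 path
}
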
89 changes: 84 additions & 5 deletions tensorflow/lite/micro/kernels/depthwise_conv_test.cc
Expand Up @@ -78,15 +78,15 @@ TfLiteStatus ValidateDepthwiseConvGoldens(
return kTfLiteOk;
}

+ template <typename T, typename BiasT>
void TestDepthwiseConvQuantizedPerChannel(
- int* input_dims_data, const float* input_data, int8_t* input_quantized,
+ int* input_dims_data, const float* input_data, T* input_quantized,
float input_scale, int input_zero_point, int* filter_dims_data,
const float* filter_data, int8_t* filter_data_quantized,
- int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
+ int* bias_dims_data, const float* bias_data, BiasT* bias_data_quantized,
int* output_dims_data, const float* expected_output_data,
- int8_t* expected_output_data_quantized, int8_t* output_data,
- float output_scale, int output_zero_point,
- TfLiteDepthwiseConvParams* conv_params,
+ T* expected_output_data_quantized, T* output_data, float output_scale,
+ int output_zero_point, TfLiteDepthwiseConvParams* conv_params,
TfLiteType filter_packed_type = kTfLiteNoType) {
TfLiteType filter_packed_type = kTfLiteNoType) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
@@ -147,6 +147,42 @@ void TestDepthwiseConvQuantizedPerChannel(
1.0, tensors_size, tensors));
}

void TestDepthwiseConvQuantizedPerChannel(
int* input_dims_data, const float* input_data, int8_t* input_quantized,
float input_scale, int input_zero_point, int* filter_dims_data,
const float* filter_data, int8_t* filter_data_quantized,
int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
int* output_dims_data, const float* expected_output_data,
int8_t* expected_output_data_quantized, int8_t* output_data,
float output_scale, int output_zero_point,
TfLiteDepthwiseConvParams* conv_params,
TfLiteType filter_packed_type = kTfLiteNoType) {
return TestDepthwiseConvQuantizedPerChannel<int8_t, int32_t>(
input_dims_data, input_data, input_quantized, input_scale,
input_zero_point, filter_dims_data, filter_data, filter_data_quantized,
bias_dims_data, bias_data, bias_data_quantized, output_dims_data,
expected_output_data, expected_output_data_quantized, output_data,
output_scale, output_zero_point, conv_params, filter_packed_type);
}

void TestDepthwiseConvQuantizedPerChannel(
int* input_dims_data, const float* input_data, int16_t* input_quantized,
float input_scale, int input_zero_point, int* filter_dims_data,
const float* filter_data, int8_t* filter_data_quantized,
int* bias_dims_data, const float* bias_data, int64_t* bias_data_quantized,
int* output_dims_data, const float* expected_output_data,
int16_t* expected_output_data_quantized, int16_t* output_data,
float output_scale, int output_zero_point,
TfLiteDepthwiseConvParams* conv_params,
TfLiteType filter_packed_type = kTfLiteNoType) {
return TestDepthwiseConvQuantizedPerChannel<int16_t, int64_t>(
input_dims_data, input_data, input_quantized, input_scale,
input_zero_point, filter_dims_data, filter_data, filter_data_quantized,
bias_dims_data, bias_data, bias_data_quantized, output_dims_data,
expected_output_data, expected_output_data_quantized, output_data,
output_scale, output_zero_point, conv_params, filter_packed_type);
}

// Xtensa kernels do not support float activations, and the corresponding tests
// are disabled. As a result, helper functions that are only needed for float
// kernel tests also need to be ifdef'd out to avoid build errors due to unused
@@ -989,4 +1025,47 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) {
output_scale, output_zero_point, &conv_params);
}

TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16InputInt8Filter) {
const int input_elements = 12;
int input_shape[] = {4, 1, 3, 2, 2};
const float input_values[] = {-547, 108, -682, 540, -161, -539, 9, -482,
-859, 84, 153, -726, 523, 702, -172, -936};
const int filter_elements = 16;
int filter_shape[] = {4, 1, 2, 2, 4};
const float filter_values[] = {1, 2, 3, 4, -9, 10, -11, 12,
5, 6, 7, 8, 13, -14, 15, -16};
const int bias_elements = 4;
int bias_shape[] = {4, 1, 1, 1, 4};
const int output_elements = 8;
const float bias_values[] = {1, 2, 3, 4};
const float golden[] = {
4894, -9009, -16596, 10268, -2564, -7483, -6599, 4356,
};
int output_shape[] = {4, 1, 2, 1, 4};
const int output_dims_count = 8;
int16_t output_data[output_dims_count];

const float input_scale = 0.5;
const float output_scale = 1.0f;
const int input_zero_point = 0;
const int output_zero_point = 0;

int16_t input_quantized[input_elements];
int8_t filter_quantized[filter_elements];
int64_t bias_quantized[bias_elements];
int16_t golden_quantized[output_elements];

TfLiteDepthwiseConvParams conv_params;
conv_params.activation = kTfLiteActNone;
conv_params.dilation_width_factor = 1;
conv_params.dilation_height_factor = 1;
conv_params.stride_height = 1;
conv_params.stride_width = 1;

tflite::testing::TestDepthwiseConvQuantizedPerChannel(
input_shape, input_values, input_quantized, input_scale, input_zero_point,
filter_shape, filter_values, filter_quantized, bias_shape, bias_values,
bias_quantized, output_shape, golden, golden_quantized, output_data,
output_scale, output_zero_point, &conv_params);
}
TF_LITE_MICRO_TESTS_END
37 changes: 37 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
Expand Up @@ -48,6 +48,18 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, DepthwiseConvPrepare(context, node));
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);

// For int16 input, only the reference-kernel fallback is used, so there is
// no need to prepare the Hifi/Vision kernel.
if (input->type == kTfLiteInt16) {
micro_context->DeallocateTempTfLiteTensor(input);
return kTfLiteOk;
}
micro_context->DeallocateTempTfLiteTensor(input);

#if defined(HIFI4) || defined(HIFI5)
TF_LITE_ENSURE_OK(context, DepthwiseConvPrepareHifi(context, node));
@@ -114,6 +126,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
break;
}
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
reference_integer_ops::DepthwiseConvPerChannel(
DepthwiseConvParamsQuantized(params, op_data.reference_op_data),
op_data.reference_op_data.per_channel_output_multiplier,
op_data.reference_op_data.per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(&filter_int8),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default:
MicroPrintf("Filter type %s (%d) for input type %s not supported.",
TfLiteTypeGetName(filter->type), filter->type,
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
break;
}
default:
MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
input->type);
