Skip to content

Commit

Permalink
Update CMSIS-NN prepare functions to exit earlier (#2630)
Browse files Browse the repository at this point in the history
* Add additional type checks to various CMSIS-NN prepare functions so unsupported models fail earlier

BUG=1121
Authored-by: Ryan O'Shea <ryan.oshea3@arm.com>
Change-Id: Ic6481873f064a94a4dd0b4a49790842180d73dd9
  • Loading branch information
ArmRyan authored Jul 22, 2024
1 parent eec75f0 commit 5c2ac76
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 9 deletions.
11 changes: 10 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/add.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -301,6 +301,15 @@ TfLiteStatus PrepareAdd(TfLiteContext* context, TfLiteNode* node) {
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);

TF_LITE_ENSURE_EQ(context, input1->type, output->type);
TF_LITE_ENSURE_MSG(
context,
input1->type == kTfLiteFloat32 || input1->type == kTfLiteInt32 ||
input1->type == kTfLiteInt16 || input1->type == kTfLiteInt8,
"Input data type not supported");
TF_LITE_ENSURE_MSG(context, input1->type == input2->type,
"Hybrid models are not supported on TFLite Micro.");

if (input1->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0);
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteType bias_type = bias != nullptr ? bias->type : kTfLiteNoType;

TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context,
input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8,
"Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
input->type == filter->type ||
(input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
(input->type == kTfLiteInt8 &&
(filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
"Hybrid models are not supported on TFLite Micro.");

// Consistency check tensor dims
Expand Down
16 changes: 15 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -75,6 +75,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);

TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context,
input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8,
"Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
(input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 &&
(filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
"Hybrid models are not supported on TFLite Micro.");

const TfLiteType data_type = input->type;
int input_width = SizeOfDimension(input, 2);
int input_height = SizeOfDimension(input, 1);
Expand Down
16 changes: 14 additions & 2 deletions tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);

TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context,
input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8,
"Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
(input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 &&
(filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
"Hybrid models are not supported on TFLite Micro.");

const RuntimeShape filter_shape = GetTensorShape(filter);
const RuntimeShape output_shape = GetTensorShape(output);
Expand Down Expand Up @@ -125,7 +137,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_dims.c = data->accum_depth;

buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
} else {
} else if (input->type == kTfLiteInt8) {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);

int8_t* filter_data = GetTensorData<int8_t>(filter);
Expand Down
13 changes: 12 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,6 +52,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);

TF_LITE_ENSURE_MSG(
context,
input->type == output->type ||
(input->type == kTfLiteInt8 && output->type == kTfLiteInt16),
"Input and output data types are not supported together.");
TF_LITE_ENSURE_MSG(context,
input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8,
"Input data type not supported");

TF_LITE_ENSURE(context, node->user_data != nullptr);
CMSISNNSoftmaxParams* op_data =
static_cast<CMSISNNSoftmaxParams*>(node->user_data);
Expand Down
11 changes: 9 additions & 2 deletions tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
micro_context->AllocateTempInputTensor(node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);

TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context,
input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8,
"Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
input->type == filter->type ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8),
(input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && filter->type == kTfLiteInt8),
"Hybrid models are not supported on TFLite Micro.");

// Get height and width of the output.
Expand Down

0 comments on commit 5c2ac76

Please sign in to comment.