Skip to content

Commit

Permalink
Add support for int16 unidirectional lstm (ARM-software#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianLundell authored Apr 11, 2024
1 parent 1af62f8 commit 8492d82
Show file tree
Hide file tree
Showing 72 changed files with 2,938 additions and 65 deletions.
6 changes: 6 additions & 0 deletions ARM.CMSIS-NN.pdsc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
<file category="source" name="Source/PoolingFunctions/arm_avgpool_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s8.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s16_batch_offset.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_add_s8.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_add_s16.c"/>
Expand All @@ -110,16 +111,21 @@
<file category="source" name="Source/NNSupportFunctions/arm_nn_depthwise_conv_nt_t_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_mat_mul_core_1x_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_step_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_step_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c"/>
<file category="source" name="Source/ConvolutionFunctions/arm_nn_mat_mult_kernel_row_offset_s8_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s4.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s8.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_get_buffer_sizes_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_get_buffer_sizes_s8.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_vector_sum_s8.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_vector_sum_s8_s64.c"/>
<file category="source" name="Source/LSTMFunctions/arm_lstm_unidirectional_s8.c"/>
<file category="source" name="Source/LSTMFunctions/arm_lstm_unidirectional_s16.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_softmax_s8.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_nn_softmax_common_s8.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_softmax_s8_s16.c"/>
Expand Down
14 changes: 7 additions & 7 deletions Include/arm_nn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
* Description: Public header file to contain the CMSIS-NN structs for the
* TensorFlowLite micro compliant functions
*
* $Date: 19 January 2024
* $Revision: V.3.0.0
* $Date: 26 March 2024
* $Revision: V.3.1.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand Down Expand Up @@ -191,15 +191,15 @@ typedef struct
{
int32_t input_multiplier;
int32_t input_shift;
const int8_t *input_weights;
const int32_t *input_effective_bias; /**< Bias added with precomputed kernel_sum * lhs_offset*/
const void *input_weights;
const void *input_effective_bias; /**< Bias added with precomputed kernel_sum * lhs_offset*/

int32_t hidden_multiplier;
int32_t hidden_shift;
const int8_t *hidden_weights;
const int32_t *hidden_effective_bias; /**< Precomputed kernel_sum * lhs_offset*/
const void *hidden_weights;
const void *hidden_effective_bias; /**< Precomputed kernel_sum * lhs_offset*/

const int32_t *bias;
const void *bias;
arm_nn_activation_type activation_type;
} cmsis_nn_lstm_gate;

Expand Down
87 changes: 63 additions & 24 deletions Include/arm_nnfunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nnfunctions.h
* Description: Public header file for CMSIS NN Library
*
* $Date: 11 March 2024
* $Revision: V.15.0.0
* $Date: 20 February 2024
* $Revision: V.15.1.0
*
* Target : Arm(R) M-Profile Architecture
Expand Down Expand Up @@ -1475,7 +1475,7 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
int8_t *output_data);

/**
* @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add bias_data.
* @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add s32 bias_data.
* @param[in, out] vector_sum_buf Buffer for vector sums
* @param[in] vector_cols Number of vector columns
* @param[in] vector_rows Number of vector rows
Expand All @@ -1492,6 +1492,24 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t lhs_offset,
const int32_t *bias_data);

/**
* @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add s64 bias_data.
* @param[in, out] vector_sum_buf Buffer for vector sums
* @param[in] vector_cols Number of vector columns
* @param[in] vector_rows Number of vector rows
* @param[in] vector_data Vector of weights data
* @param[in] lhs_offset Constant multiplied with each sum
* @param[in] bias_data Vector of bias data, added to each sum.
* @return The function returns
* <code>ARM_CMSIS_NN_SUCCESS</code> - Successful operation
*/
arm_cmsis_nn_status arm_vector_sum_s8_s64(int64_t *vector_sum_buf,
const int32_t vector_cols,
const int32_t vector_rows,
const int8_t *vector_data,
const int32_t lhs_offset,
const int64_t *bias_data);

/**
* @brief Get size of additional buffer required by arm_fully_connected_s8().
* See also arm_vector_sum_s8, which is required if buffer size is > 0.
Expand Down Expand Up @@ -2401,13 +2419,41 @@ arm_cmsis_nn_status arm_svdf_state_s16_s8(const cmsis_nn_context *input_ctx,
const cmsis_nn_dims *output_dims,
int8_t *output_data);

/**
* @brief Get size of additional buffer required by arm_svdf_s8().
* @param[in] filter_dims dimension of filter
* @return The function returns required buffer size in bytes
*
*/
int32_t arm_svdf_s8_get_buffer_size(const cmsis_nn_dims *filter_dims);

/**
* @brief Get size of additional buffer required by arm_svdf_s8() for processors with DSP extension.
* Refer to arm_svdf_s8_get_buffer_size() for function argument details.
*
* @note Intended for compilation on Host. If compiling for an Arm target, use
* arm_svdf_s8_get_buffer_size().
*
*/
int32_t arm_svdf_s8_get_buffer_size_dsp(const cmsis_nn_dims *filter_dims);

/**
* @brief Get size of additional buffer required by arm_svdf_s8() for Arm(R) Helium Architecture case.
* Refer to arm_svdf_s8_get_buffer_size() for function argument details.
*
* @note Intended for compilation on Host. If compiling for an Arm target, use
* arm_svdf_s8_get_buffer_size().
*
*/
int32_t arm_svdf_s8_get_buffer_size_mve(const cmsis_nn_dims *filter_dims);

/**
* @defgroup LSTM LSTM Layer Functions
*
*/

/**
* @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output.
* @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output, 32 bit bias.
*
* @param[in] input Pointer to input data
* @param[out] output Pointer to output data
Expand All @@ -2428,32 +2474,25 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
cmsis_nn_lstm_context *buffers);

/**
* @brief Get size of additional buffer required by arm_svdf_s8().
* @param[in] filter_dims dimension of filter
* @return The function returns required buffer size in bytes
* @brief LSTM unidirectional function with 16 bit input and output and 16 bit gate output, 64 bit bias.
*
*/
int32_t arm_svdf_s8_get_buffer_size(const cmsis_nn_dims *filter_dims);

/**
* @brief Get size of additional buffer required by arm_svdf_s8() for processors with DSP extension.
* Refer to arm_svdf_s8_get_buffer_size() for function argument details.
* @param[in] input Pointer to input data
* @param[out] output Pointer to output data
* @param[in] params Struct containing all information about the lstm operator, see arm_nn_types.
* @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the
* lstm operator, see arm_nn_types.
*
* @note Intended for compilation on Host. If compiling for an Arm target, use
* arm_svdf_s8_get_buffer_size().
*
*/
int32_t arm_svdf_s8_get_buffer_size_dsp(const cmsis_nn_dims *filter_dims);

/**
* @brief Get size of additional buffer required by arm_svdf_s8() for Arm(R) Helium Architecture case.
* Refer to arm_svdf_s8_get_buffer_size() for function argument details.
* @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
*
* @note Intended for compilation on Host. If compiling for an Arm target, use
* arm_svdf_s8_get_buffer_size().
* @details
* 1. Supported framework: TensorFlow Lite Micro
*
*/
int32_t arm_svdf_s8_get_buffer_size_mve(const cmsis_nn_dims *filter_dims);
arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input,
int16_t *output,
const cmsis_nn_lstm_params *params,
cmsis_nn_lstm_context *buffers);

#ifdef __cplusplus
}
Expand Down
109 changes: 105 additions & 4 deletions Include/arm_nnsupportfunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nnsupportfunctions.h
* Description: Public header file of support functions for CMSIS NN Library
*
* $Date: 22 March 2024
* $Revision: V.20.0.0
* $Date: 14 February 2024
* $Revision: V.20.1.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand Down Expand Up @@ -1538,9 +1538,9 @@ __STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)

// Support functions for LSTM
/**
* @brief Update LSTM function for an iteration step
* @brief Update LSTM function for an iteration step using s8 input and output, and s16 internally.
*
* @param[in] data_in Data input pointervoid
* @param[in] data_in Data input pointer
* @param[in] hidden_in Hidden state/ recurrent input pointer
* @param[out] hidden_out Hidden state/ recurrent output pointer
* @param[in] params Struct containing all information about the lstm operator, see
Expand All @@ -1561,6 +1561,30 @@ arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
cmsis_nn_lstm_context *buffers,
const int32_t batch_offset);

/**
* @brief Update LSTM function for an iteration step using s16 input and output, and s16 internally.
*
* @param[in] data_in Data input pointer
* @param[in] hidden_in Hidden state/ recurrent input pointer
* @param[out] hidden_out Hidden state/ recurrent output pointer
* @param[in] params Struct containing all information about the lstm operator, see
* arm_nn_types.
* @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the
* lstm operator, see arm_nn_types.
* @param[in] batch_offset Number of timesteps between consecutive batches.
* E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
* For params->time_major = false, all time steps are stored continuously before the next batch, so
* batch offset = params->time_steps.
* @return The function returns ARM_CMSIS_NN_SUCCESS
*/
arm_cmsis_nn_status arm_nn_lstm_step_s16(const int16_t *data_in,
const int16_t *hidden_in,
int16_t *hidden_out,
const cmsis_nn_lstm_params *params,
cmsis_nn_lstm_context *buffers,
const int32_t batch_offset);

/**
* @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
*
Expand All @@ -1582,6 +1606,27 @@ arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in,
int16_t *output,
const int32_t batch_offset);

/**
* @brief Updates a LSTM gate for an iteration step of LSTM function, int16x8_16 version.
*
* @param[in] data_in Data input pointer
* @param[in] hidden_in Hidden state/ recurrent input pointer
* @param[in] gate_data Struct containing all information about the gate calculation, see
* arm_nn_types.
* @param[in] params Struct containing all information about the lstm_operation, see
* arm_nn_types
* @param[out] output Hidden state/ recurrent output pointer
* @param[in] batch_offset Number of timesteps between consecutive batches, see
* arm_nn_lstm_step_s16.
* @return The function returns ARM_CMSIS_NN_SUCCESS
*/
arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s16(const int16_t *data_in,
const int16_t *hidden_in,
const cmsis_nn_lstm_gate *gate_data,
const cmsis_nn_lstm_params *params,
int16_t *output,
const int32_t batch_offset);

/**
* @brief The result of the multiplication is accumulated to the passed result buffer.
* Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed by input vectors independent
Expand Down Expand Up @@ -1612,6 +1657,36 @@ arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs,
const int32_t batches,
const int32_t batch_offset);

/**
* @brief The result of the multiplication is accumulated to the passed result buffer.
* Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed by input vectors independent
* from each other).
*
* @param[in] lhs Batched vector
* @param[in] rhs Weights - input matrix (H(Rows)xW(Columns))
* @param[in] effective_bias Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
* @param[out] dst Output
* @param[in] dst_multiplier Multiplier for quantization
* @param[in] dst_shift Shift for quantization
* @param[in] rhs_cols Vector/matrix column length
* @param[in] rhs_rows Row count of matrix
* @param[in] batches Batch size
* @param[in] batch_offset Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s16.
Note that the output is always stored with sequential batches.
* @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
*/
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
const int8_t *rhs,
const int64_t *effective_bias,
int16_t *dst,
const int32_t dst_multiplier,
const int32_t dst_shift,
const int32_t rhs_cols,
const int32_t rhs_rows,
const int32_t batches,
const int32_t batch_offset);

/**
* @brief s16 elementwise multiplication with s8 output
* @param[in] input_1_vect pointer to input vector 1
Expand All @@ -1638,6 +1713,32 @@ arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
const int32_t batch_size,
const int32_t batch_offset);

/**
* @brief s16 elementwise multiplication with s16 output
* @param[in] input_1_vect pointer to input vector 1
* @param[in] input_2_vect pointer to input vector 2
* @param[in,out] output pointer to output vector
* @param[in] out_offset output offset
* @param[in] out_mult output multiplier
* @param[in] out_shift output shift
* @param[in] block_size number of samples per batch
* @param[in] batch_size number of batches
* @param[in] batch_offset Number of timesteps between consecutive batches in output, see
* arm_nn_lstm_step_s16. Note that it is assumed that the input is stored with sequential batches.
* @return The function returns ARM_CMSIS_NN_SUCCESS
*
* @details Supported framework: TensorFlow Lite micro
*/
arm_cmsis_nn_status arm_elementwise_mul_s16_batch_offset(const int16_t *input_1_vect,
const int16_t *input_2_vect,
int16_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t block_size,
const int32_t batch_size,
const int32_t batch_offset);

/**
* @brief s16 elementwise multiplication. The result of the multiplication is accumulated to the passed result buffer.
* @param[in] input_1_vect pointer to input vector 1
Expand Down
Loading

0 comments on commit 8492d82

Please sign in to comment.