diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 527b85a8ea8..97d6897b9c8 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -407,6 +407,15 @@ cc_library( copts = micro_copts(), ) +cc_library( + name = "static_vector", + hdrs = ["static_vector.h"], + copts = micro_copts(), + deps = [ + "//tensorflow/lite/kernels:op_macros", + ], +) + cc_library( name = "system_setup", srcs = [ @@ -616,6 +625,18 @@ cc_test( ], ) +cc_test( + name = "static_vector_test", + size = "small", + srcs = [ + "static_vector_test.cc", + ], + deps = [ + ":static_vector", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + bzl_library( name = "build_def_bzl", srcs = ["build_def.bzl"], diff --git a/tensorflow/lite/micro/static_vector.h b/tensorflow/lite/micro/static_vector.h new file mode 100644 index 00000000000..8b9e06392fb --- /dev/null +++ b/tensorflow/lite/micro/static_vector.h @@ -0,0 +1,83 @@ +// Copyright 2024 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_ +#define TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/kernels/op_macros.h" // for TF_LITE_ASSERT + +namespace tflite { + +template +class StaticVector { + // A staticlly-allocated vector. Add to the interface as needed. + + private: + std::array array_; + std::size_t size_{0}; + + public: + using iterator = typename decltype(array_)::iterator; + using const_iterator = typename decltype(array_)::const_iterator; + using pointer = typename decltype(array_)::pointer; + using reference = typename decltype(array_)::reference; + using const_reference = typename decltype(array_)::const_reference; + + StaticVector() {} + + StaticVector(std::initializer_list values) { + for (const T& v : values) { + push_back(v); + } + } + + static constexpr std::size_t max_size() { return MaxSize; } + std::size_t size() const { return size_; } + bool full() const { return size() == max_size(); } + iterator begin() { return array_.begin(); } + const_iterator begin() const { return array_.begin(); } + iterator end() { return begin() + size(); } + const_iterator end() const { return begin() + size(); } + pointer data() { return array_.data(); } + reference operator[](int i) { return array_[i]; } + const_reference operator[](int i) const { return array_[i]; } + void clear() { size_ = 0; } + + template + bool operator==(const StaticVector& other) const { + return std::equal(begin(), end(), other.begin(), other.end()); + } + + template + bool operator!=(const StaticVector& other) const { + return !(*this == other); + } + + void push_back(const T& t) { + TF_LITE_ASSERT(!full()); + *end() = t; + ++size_; + } +}; + +template +StaticVector(T, U...) -> StaticVector; + +} // end namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_ diff --git a/tensorflow/lite/micro/static_vector_test.cc b/tensorflow/lite/micro/static_vector_test.cc new file mode 100644 index 00000000000..6d601bcf89a --- /dev/null +++ b/tensorflow/lite/micro/static_vector_test.cc @@ -0,0 +1,82 @@ +// Copyright 2024 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/micro/static_vector.h" + +#include "tensorflow/lite/micro/testing/micro_test.h" + +using tflite::StaticVector; + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(StaticVectorPushBack) { + StaticVector a; + TF_LITE_MICRO_EXPECT(a.max_size() == 4); + TF_LITE_MICRO_EXPECT(a.size() == 0); + + a.push_back(1); + TF_LITE_MICRO_EXPECT(a.size() == 1); + TF_LITE_MICRO_EXPECT(a[0] == 1); + + a.push_back(2); + TF_LITE_MICRO_EXPECT(a.size() == 2); + TF_LITE_MICRO_EXPECT(a[1] == 2); + + a.push_back(3); + TF_LITE_MICRO_EXPECT(a.size() == 3); + TF_LITE_MICRO_EXPECT(a[2] == 3); +} + +TF_LITE_MICRO_TEST(StaticVectorInitializationPartial) { + const StaticVector a{1, 2, 3}; + TF_LITE_MICRO_EXPECT(a.max_size() == 4); + TF_LITE_MICRO_EXPECT(a.size() == 3); + TF_LITE_MICRO_EXPECT(a[0] == 1); + TF_LITE_MICRO_EXPECT(a[1] == 2); + TF_LITE_MICRO_EXPECT(a[2] == 3); +} + +TF_LITE_MICRO_TEST(StaticVectorInitializationFull) { + const StaticVector b{1, 2, 3}; + TF_LITE_MICRO_EXPECT(b.max_size() == 3); + TF_LITE_MICRO_EXPECT(b.size() == 3); +} + +TF_LITE_MICRO_TEST(StaticVectorEquality) { + const StaticVector a{1, 2, 3}; + const StaticVector b{1, 2, 3}; + TF_LITE_MICRO_EXPECT(a == b); + TF_LITE_MICRO_EXPECT(!(a != b)); +} + +TF_LITE_MICRO_TEST(StaticVectorInequality) { + const StaticVector a{1, 2, 3}; + const StaticVector b{3, 2, 1}; + TF_LITE_MICRO_EXPECT(a != b); + TF_LITE_MICRO_EXPECT(!(a == b)); +} + +TF_LITE_MICRO_TEST(StaticVectorSizeInequality) { + const StaticVector a{1, 2}; + const StaticVector b{1, 2, 3}; + TF_LITE_MICRO_EXPECT(a != b); +} + +TF_LITE_MICRO_TEST(StaticVectorPartialSizeInequality) { + const StaticVector a{1, 2}; + const StaticVector b{1, 2, 3}; + TF_LITE_MICRO_EXPECT(a != b); +} + +TF_LITE_MICRO_TESTS_END