Skip to content

Commit

Permalink
feat: add fixed-capacity, statically-allocated type tflite::StaticVec…
Browse files Browse the repository at this point in the history
…tor (#2642)

Add a type, tflite::StaticVector, which behaves like std::vector, but
which avoids heap memory allocation.

BUG=#2636
  • Loading branch information
rkuester authored Aug 7, 2024
1 parent 8f9a923 commit d3475aa
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tensorflow/lite/micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,15 @@ cc_library(
copts = micro_copts(),
)

cc_library(
name = "static_vector",
hdrs = ["static_vector.h"],
copts = micro_copts(),
deps = [
"//tensorflow/lite/kernels:op_macros",
],
)

cc_library(
name = "system_setup",
srcs = [
Expand Down Expand Up @@ -616,6 +625,18 @@ cc_test(
],
)

cc_test(
name = "static_vector_test",
size = "small",
srcs = [
"static_vector_test.cc",
],
deps = [
":static_vector",
"//tensorflow/lite/micro/testing:micro_test",
],
)

bzl_library(
name = "build_def_bzl",
srcs = ["build_def.bzl"],
Expand Down
83 changes: 83 additions & 0 deletions tensorflow/lite/micro/static_vector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_
#define TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_

#include <array>
#include <cassert>
#include <cstddef>

#include "tensorflow/lite/kernels/op_macros.h" // for TF_LITE_ASSERT

namespace tflite {

template <typename T, std::size_t MaxSize>
class StaticVector {
// A staticlly-allocated vector. Add to the interface as needed.

private:
std::array<T, MaxSize> array_;
std::size_t size_{0};

public:
using iterator = typename decltype(array_)::iterator;
using const_iterator = typename decltype(array_)::const_iterator;
using pointer = typename decltype(array_)::pointer;
using reference = typename decltype(array_)::reference;
using const_reference = typename decltype(array_)::const_reference;

StaticVector() {}

StaticVector(std::initializer_list<T> values) {
for (const T& v : values) {
push_back(v);
}
}

static constexpr std::size_t max_size() { return MaxSize; }
std::size_t size() const { return size_; }
bool full() const { return size() == max_size(); }
iterator begin() { return array_.begin(); }
const_iterator begin() const { return array_.begin(); }
iterator end() { return begin() + size(); }
const_iterator end() const { return begin() + size(); }
pointer data() { return array_.data(); }
reference operator[](int i) { return array_[i]; }
const_reference operator[](int i) const { return array_[i]; }
void clear() { size_ = 0; }

template <std::size_t N>
bool operator==(const StaticVector<T, N>& other) const {
return std::equal(begin(), end(), other.begin(), other.end());
}

template <std::size_t N>
bool operator!=(const StaticVector<T, N>& other) const {
return !(*this == other);
}

void push_back(const T& t) {
TF_LITE_ASSERT(!full());
*end() = t;
++size_;
}
};

template <typename T, typename... U>
StaticVector(T, U...) -> StaticVector<T, 1 + sizeof...(U)>;

} // end namespace tflite

#endif // TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_
82 changes: 82 additions & 0 deletions tensorflow/lite/micro/static_vector_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/lite/micro/static_vector.h"

#include "tensorflow/lite/micro/testing/micro_test.h"

using tflite::StaticVector;

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(StaticVectorPushBack) {
StaticVector<int, 4> a;
TF_LITE_MICRO_EXPECT(a.max_size() == 4);
TF_LITE_MICRO_EXPECT(a.size() == 0);

a.push_back(1);
TF_LITE_MICRO_EXPECT(a.size() == 1);
TF_LITE_MICRO_EXPECT(a[0] == 1);

a.push_back(2);
TF_LITE_MICRO_EXPECT(a.size() == 2);
TF_LITE_MICRO_EXPECT(a[1] == 2);

a.push_back(3);
TF_LITE_MICRO_EXPECT(a.size() == 3);
TF_LITE_MICRO_EXPECT(a[2] == 3);
}

TF_LITE_MICRO_TEST(StaticVectorInitializationPartial) {
const StaticVector<int, 4> a{1, 2, 3};
TF_LITE_MICRO_EXPECT(a.max_size() == 4);
TF_LITE_MICRO_EXPECT(a.size() == 3);
TF_LITE_MICRO_EXPECT(a[0] == 1);
TF_LITE_MICRO_EXPECT(a[1] == 2);
TF_LITE_MICRO_EXPECT(a[2] == 3);
}

TF_LITE_MICRO_TEST(StaticVectorInitializationFull) {
const StaticVector b{1, 2, 3};
TF_LITE_MICRO_EXPECT(b.max_size() == 3);
TF_LITE_MICRO_EXPECT(b.size() == 3);
}

TF_LITE_MICRO_TEST(StaticVectorEquality) {
const StaticVector a{1, 2, 3};
const StaticVector b{1, 2, 3};
TF_LITE_MICRO_EXPECT(a == b);
TF_LITE_MICRO_EXPECT(!(a != b));
}

TF_LITE_MICRO_TEST(StaticVectorInequality) {
const StaticVector a{1, 2, 3};
const StaticVector b{3, 2, 1};
TF_LITE_MICRO_EXPECT(a != b);
TF_LITE_MICRO_EXPECT(!(a == b));
}

TF_LITE_MICRO_TEST(StaticVectorSizeInequality) {
const StaticVector a{1, 2};
const StaticVector b{1, 2, 3};
TF_LITE_MICRO_EXPECT(a != b);
}

TF_LITE_MICRO_TEST(StaticVectorPartialSizeInequality) {
const StaticVector<int, 3> a{1, 2};
const StaticVector<int, 3> b{1, 2, 3};
TF_LITE_MICRO_EXPECT(a != b);
}

TF_LITE_MICRO_TESTS_END

0 comments on commit d3475aa

Please sign in to comment.