diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 97d6897b9c8..1753465425d 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -568,6 +568,18 @@ cc_test( ], ) +cc_test( + name = "span_test", + size = "small", + srcs = [ + "span_test.cc", + ], + deps = [ + ":span", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + cc_test( name = "testing_helpers_test", srcs = [ diff --git a/tensorflow/lite/micro/span.h b/tensorflow/lite/micro/span.h index eeb4c0b6ac5..9399f1de901 100644 --- a/tensorflow/lite/micro/span.h +++ b/tensorflow/lite/micro/span.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_SPAN_H_ #define TENSORFLOW_LITE_MICRO_SPAN_H_ +#include #include namespace tflite { @@ -26,6 +27,13 @@ class Span { public: constexpr Span(T* data, size_t size) noexcept : data_(data), size_(size) {} + template + constexpr Span(T (&data)[N]) noexcept : data_(data), size_(N) {} + + template + constexpr Span(std::array& array) noexcept + : data_(array.data()), size_(N) {} + constexpr T& operator[](size_t idx) const noexcept { return *(data_ + idx); } constexpr T* data() const noexcept { return data_; } @@ -36,6 +44,26 @@ class Span { size_t size_; }; +template +bool operator==(const Span& a, const Span& b) { + if (a.size() != b.size()) { + return false; + } + + for (size_t i = 0; i < a.size(); ++i) { + if (a[i] != b[i]) { + return false; + } + } + + return true; +} + +template +bool operator!=(const Span& a, const Span& b) { + return !(a == b); +} + } // end namespace tflite #endif // TENSORFLOW_LITE_MICRO_SPAN_H_ diff --git a/tensorflow/lite/micro/span_test.cc b/tensorflow/lite/micro/span_test.cc new file mode 100644 index 00000000000..ef906c6df70 --- /dev/null +++ b/tensorflow/lite/micro/span_test.cc @@ -0,0 +1,59 @@ +// Copyright 2024 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/micro/span.h" + +#include + +#include "tensorflow/lite/micro/testing/micro_test.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestArrayInitialization) { + int a[]{1, 2, 3}; + tflite::Span s{a}; + TF_LITE_MICRO_EXPECT(s.data() == a); + TF_LITE_MICRO_EXPECT(s.size() == sizeof(a) / sizeof(int)); +} + +TF_LITE_MICRO_TEST(TestStdArrayInitialization) { + std::array a; + tflite::Span s{a}; + TF_LITE_MICRO_EXPECT(s.data() == a.data()); + TF_LITE_MICRO_EXPECT(s.size() == a.size()); +} + +TF_LITE_MICRO_TEST(TestEquality) { + constexpr int a[]{1, 2, 3}; + constexpr int b[]{1, 2, 3}; + constexpr int c[]{3, 2, 1}; + tflite::Span s_a{a}; + tflite::Span s_b{b}; + tflite::Span s_c{c}; + TF_LITE_MICRO_EXPECT_TRUE(s_a == s_b); + TF_LITE_MICRO_EXPECT_FALSE(s_a == s_c); +} + +TF_LITE_MICRO_TEST(TestInequality) { + constexpr int a[]{1, 2, 3}; + constexpr int b[]{1, 2, 3}; + constexpr int c[]{3, 2, 1}; + tflite::Span s_a{a}; + tflite::Span s_b{b}; + tflite::Span s_c{c}; + TF_LITE_MICRO_EXPECT_FALSE(s_a != s_b); + TF_LITE_MICRO_EXPECT_TRUE(s_a != s_c); +} + +TF_LITE_MICRO_TESTS_END