diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 56d6fe019..f327fa406 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -7,9 +7,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") def base_deps(): """Base evaluator and test dependencies.""" - # Abseil LTS 20240116.2 - ABSL_SHA1 = "d7aaad83b488fd62bd51c81ecf16cd938532cc0a" - ABSL_SHA256 = "68e7d36d621769ab500b2ebeec6a7910420566874b4b33b340a04bd70e67fe43" + # Abseil LTS 20240722.0 + ABSL_SHA1 = "4447c7562e3bc702ade25105912dce503f0c4010" + ABSL_SHA256 = "d8342ad77aa9e16103c486b615460c24a695a1f04cdb760eb02fef780df99759" http_archive( name = "com_google_absl", urls = ["https://github.com/abseil/abseil-cpp/archive/" + ABSL_SHA1 + ".zip"], @@ -17,9 +17,9 @@ def base_deps(): sha256 = ABSL_SHA256, ) - # v1.14.0 - GOOGLETEST_SHA1 = "f8d7d77c06936315286eb55f8de22cd23c188571" - GOOGLETEST_SHA256 = "b976cf4fd57b318afdb1bdb27fc708904b3e4bed482859eb94ba2b4bdd077fe2" + # v1.15.2 + GOOGLETEST_SHA1 = "b514bdc898e2951020cbdca1304b75f5950d1f59" + GOOGLETEST_SHA256 = "8c0ceafa3ea24bf78e3519b7846d99e76c45899aa4dac4d64e7dd62e495de9fd" http_archive( name = "com_google_googletest", urls = ["https://github.com/google/googletest/archive/" + GOOGLETEST_SHA1 + ".zip"], diff --git a/common/BUILD b/common/BUILD index 891f8c68f..8f5bb9ce3 100644 --- a/common/BUILD +++ b/common/BUILD @@ -913,3 +913,162 @@ cc_library( "@com_google_absl//absl/utility", ], ) + +cc_library( + name = "arena_string", + hdrs = ["arena_string.h"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "arena_string_test", + srcs = ["arena_string_test.cc"], + deps = [ + ":arena_string", + "//internal:testing", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "arena_string_pool", + hdrs = ["arena_string_pool.h"], + deps = [ + ":arena_string", + "//internal:string_pool", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "arena_string_pool_test", + srcs = ["arena_string_pool_test.cc"], + deps = [ + ":arena_string_pool", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_pool", + srcs = ["type_pool.cc"], + hdrs = ["type_pool.h"], + deps = [ + ":arena_string", + ":arena_string_pool", + ":type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_pool_test", + srcs = ["type_pool_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_proto", + srcs = ["type_proto.cc"], + hdrs = ["type_proto.h"], + deps = [ + ":type", + ":type_kind", + ":type_pool", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_proto_test", + srcs = ["type_proto_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + ":type_proto", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_proto_v1alpha1", + srcs = ["type_proto_v1alpha1.cc"], + hdrs = ["type_proto_v1alpha1.h"], + deps = [ + ":type", + ":type_kind", + ":type_pool", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_proto_v1alpha1_test", + srcs = ["type_proto_v1alpha1_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + ":type_proto_v1alpha1", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/arena_string.h b/common/arena_string.h new file mode 100644 index 000000000..f1b58164f --- /dev/null +++ b/common/arena_string.h @@ -0,0 +1,254 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" + +namespace cel { + +class ArenaStringPool; + +// Bug in current Abseil LTS. Fixed in +// https://github.com/abseil/abseil-cpp/commit/fd7713cb9a97c49096211ff40de280b6cebbb21c +// which is not yet in an LTS. +#if defined(__clang__) && (!defined(__clang_major__) || __clang_major__ >= 13) +#define CEL_ATTRIBUTE_ARENA_STRING_VIEW ABSL_ATTRIBUTE_VIEW +#else +#define CEL_ATTRIBUTE_ARENA_STRING_VIEW +#endif + +// `ArenaString` is a read-only string which is either backed by a static string +// literal or owned by the `ArenaStringPool` that created it. It is compatible +// with `absl::string_view` and is implicitly convertible to it. +class CEL_ATTRIBUTE_ARENA_STRING_VIEW ArenaString final { + private: + template + static constexpr bool IsStringLiteral(const char (&string)[N]) { + static_assert(N > 0); + for (size_t i = 0; i < N - 1; ++i) { + if (string[i] == '\0') { + return false; + } + } + return string[N - 1] == '\0'; + } + + public: + using traits_type = std::char_traits; + using value_type = char; + using pointer = char*; + using const_pointer = const char*; + using reference = char&; + using const_reference = const char&; + using const_iterator = const_pointer; + using iterator = const_iterator; + using const_reverse_iterator = std::reverse_iterator; + using reverse_iterator = const_reverse_iterator; + using size_type = uint32_t; + using difference_type = int32_t; + + using absl_internal_is_view = std::true_type; + + template + static constexpr ArenaString Static(const char (&string)[N]) +#if ABSL_HAVE_ATTRIBUTE(enable_if) + __attribute__((enable_if(ArenaString::IsStringLiteral(string), + "chosen when 'string' is a string literal"))) +#endif + { + static_assert(N > 0); + static_assert(N - 1 <= std::numeric_limits::max()); + return ArenaString(string); + } + + ArenaString() = default; + ArenaString(const ArenaString&) = default; + ArenaString& operator=(const ArenaString&) = default; + + constexpr size_type size() const { return size_; } + + constexpr bool empty() const { return size() == 0; } + + constexpr size_type max_size() const { + return std::numeric_limits::max(); + } + + constexpr absl::Nonnull data() const { return data_; } + + constexpr const_reference front() const { + ABSL_ASSERT(!empty()); + return data()[0]; + } + + constexpr const_reference back() const { + ABSL_ASSERT(!empty()); + return data()[size() - 1]; + } + + constexpr const_reference operator[](size_type index) const { + ABSL_ASSERT(index < size()); + return data()[index]; + } + + constexpr void remove_prefix(size_type n) { + ABSL_ASSERT(n <= size()); + data_ += n; + size_ -= n; + } + + constexpr void remove_suffix(size_type n) { + ABSL_ASSERT(n <= size()); + size_ -= n; + } + + constexpr const_iterator begin() const { return data(); } + + constexpr const_iterator cbegin() const { return begin(); } + + constexpr const_iterator end() const { return data() + size(); } + + constexpr const_iterator cend() const { return end(); } + + constexpr const_reverse_iterator rbegin() const { + return std::make_reverse_iterator(end()); + } + + constexpr const_reverse_iterator crbegin() const { return rbegin(); } + + constexpr const_reverse_iterator rend() const { + return std::make_reverse_iterator(begin()); + } + + constexpr const_reverse_iterator crend() const { return rend(); } + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr operator absl::string_view() const { + return absl::string_view(data(), size()); + } + + private: + friend class ArenaStringPool; + + constexpr explicit ArenaString(absl::string_view value) + : data_(value.data()), size_(static_cast(value.size())) { + ABSL_ASSERT(value.data() != nullptr); + ABSL_ASSERT(value.size() <= max_size()); + } + + absl::Nonnull data_ = ""; + size_type size_ = 0; +}; + +constexpr bool operator==(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) == + absl::implicit_cast(rhs); +} + +constexpr bool operator==(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) == rhs; +} + +constexpr bool operator==(absl::string_view lhs, ArenaString rhs) { + return lhs == absl::implicit_cast(rhs); +} + +constexpr bool operator!=(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) != + absl::implicit_cast(rhs); +} + +constexpr bool operator!=(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) != rhs; +} + +constexpr bool operator!=(absl::string_view lhs, ArenaString rhs) { + return lhs != absl::implicit_cast(rhs); +} + +constexpr bool operator<(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) < + absl::implicit_cast(rhs); +} + +constexpr bool operator<(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) < rhs; +} + +constexpr bool operator<(absl::string_view lhs, ArenaString rhs) { + return lhs < absl::implicit_cast(rhs); +} + +constexpr bool operator<=(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) <= + absl::implicit_cast(rhs); +} + +constexpr bool operator<=(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) <= rhs; +} + +constexpr bool operator<=(absl::string_view lhs, ArenaString rhs) { + return lhs <= absl::implicit_cast(rhs); +} + +constexpr bool operator>(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) > + absl::implicit_cast(rhs); +} + +constexpr bool operator>(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) > rhs; +} + +constexpr bool operator>(absl::string_view lhs, ArenaString rhs) { + return lhs > absl::implicit_cast(rhs); +} + +constexpr bool operator>=(ArenaString lhs, ArenaString rhs) { + return absl::implicit_cast(lhs) >= + absl::implicit_cast(rhs); +} + +constexpr bool operator>=(ArenaString lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) >= rhs; +} + +constexpr bool operator>=(absl::string_view lhs, ArenaString rhs) { + return lhs >= absl::implicit_cast(rhs); +} + +template +H AbslHashValue(H state, ArenaString arena_string) { + return H::combine(std::move(state), + absl::implicit_cast(arena_string)); +} + +#undef CEL_ATTRIBUTE_ARENA_STRING_VIEW + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ diff --git a/common/arena_string_pool.h b/common/arena_string_pool.h new file mode 100644 index 000000000..97de1334a --- /dev/null +++ b/common/arena_string_pool.h @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "common/arena_string.h" +#include "internal/string_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class ArenaStringPool; + +absl::Nonnull> NewArenaStringPool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class ArenaStringPool final { + public: + ArenaStringPool(const ArenaStringPool&) = delete; + ArenaStringPool(ArenaStringPool&&) = delete; + ArenaStringPool& operator=(const ArenaStringPool&) = delete; + ArenaStringPool& operator=(ArenaStringPool&&) = delete; + + ArenaString InternString(absl::string_view string) { + return ArenaString(strings_.InternString(string)); + } + + ArenaString InternString(ArenaString) = delete; + + private: + friend absl::Nonnull> NewArenaStringPool( + absl::Nonnull); + + explicit ArenaStringPool(absl::Nonnull arena) + : strings_(arena) {} + + internal::StringPool strings_; +}; + +inline absl::Nonnull> NewArenaStringPool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return std::unique_ptr(new ArenaStringPool(arena)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ diff --git a/common/arena_string_pool_test.cc b/common/arena_string_pool_test.cc new file mode 100644 index 000000000..dda0fa864 --- /dev/null +++ b/common/arena_string_pool_test.cc @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arena_string_pool.h" + +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(ArenaStringPool, InternString) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString("Hello World!"); + auto got = string_pool->InternString("Hello World!"); + EXPECT_EQ(expected.data(), got.data()); +} + +} // namespace +} // namespace cel diff --git a/common/arena_string_test.cc b/common/arena_string_test.cc new file mode 100644 index 000000000..9d80b9828 --- /dev/null +++ b/common/arena_string_test.cc @@ -0,0 +1,126 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arena_string.h" + +#include "absl/hash/hash.h" +#include "absl/hash/hash_testing.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Eq; +using testing::Ge; +using testing::Gt; +using testing::IsEmpty; +using testing::Le; +using testing::Lt; +using testing::Ne; +using testing::SizeIs; + +TEST(ArenaString, Default) { + ArenaString string; + EXPECT_THAT(string, IsEmpty()); + EXPECT_THAT(string, SizeIs(0)); + EXPECT_THAT(string, Eq(ArenaString())); +} + +TEST(ArenaString, Iterator) { + ArenaString string = ArenaString::Static("Hello World!"); + auto it = string.cbegin(); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(it, Eq(string.cend())); +} + +TEST(ArenaString, ReverseIterator) { + ArenaString string = ArenaString::Static("Hello World!"); + auto it = string.crbegin(); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(it, Eq(string.crend())); +} + +TEST(ArenaString, RemovePrefix) { + ArenaString string = ArenaString::Static("Hello World!"); + string.remove_prefix(6); + EXPECT_EQ(string, "World!"); +} + +TEST(ArenaString, RemoveSuffix) { + ArenaString string = ArenaString::Static("Hello World!"); + string.remove_suffix(7); + EXPECT_EQ(string, "Hello"); +} + +TEST(ArenaString, Equal) { + EXPECT_THAT(ArenaString::Static("1"), Eq(ArenaString::Static("1"))); +} + +TEST(ArenaString, NotEqual) { + EXPECT_THAT(ArenaString::Static("1"), Ne(ArenaString::Static("2"))); +} + +TEST(ArenaString, Less) { + EXPECT_THAT(ArenaString::Static("1"), Lt(ArenaString::Static("2"))); +} + +TEST(ArenaString, LessEqual) { + EXPECT_THAT(ArenaString::Static("1"), Le(ArenaString::Static("1"))); +} + +TEST(ArenaString, Greater) { + EXPECT_THAT(ArenaString::Static("2"), Gt(ArenaString::Static("1"))); +} + +TEST(ArenaString, GreaterEqual) { + EXPECT_THAT(ArenaString::Static("1"), Ge(ArenaString::Static("1"))); +} + +TEST(ArenaString, ImplementsAbslHashCorrectly) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {ArenaString::Static(""), ArenaString::Static("Hello World!"), + ArenaString::Static("How much wood could a woodchuck chuck if a " + "woodchuck could chuck wood?")})); +} + +TEST(ArenaString, Hash) { + EXPECT_EQ(absl::HashOf(ArenaString::Static("Hello World!")), + absl::HashOf(absl::string_view("Hello World!"))); +} + +} // namespace +} // namespace cel diff --git a/common/type.h b/common/type.h index 9d849dc60..3652a89d4 100644 --- a/common/type.h +++ b/common/type.h @@ -860,6 +860,20 @@ class TypeParameters final { }; }; +inline bool operator==(const TypeParameters& lhs, const TypeParameters& rhs) { + return absl::c_equal(lhs, rhs); +} + +inline bool operator!=(const TypeParameters& lhs, const TypeParameters& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TypeParameters& parameters) { + return H::combine_contiguous(std::move(state), parameters.data(), + parameters.size()); +} + // Now that TypeParameters is defined, we can define `GetParameters()` for most // types. diff --git a/common/type_pool.cc b/common/type_pool.cc new file mode 100644 index 000000000..df98d8bbd --- /dev/null +++ b/common/type_pool.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_pool.h" + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +ListType TypePool::MakeListType(const Type& element) { + return list_type_pool_.InternListType(element); +} + +MapType TypePool::MakeMapType(const Type& key, const Type& value) { + return map_type_pool_.InternMapType(key, value); +} + +StructType TypePool::MakeStructType(absl::string_view name) { + if (descriptor_pool_ != nullptr) { + const google::protobuf::Descriptor* descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + return MessageType(descriptor); + } + } + return common_internal::MakeBasicStructType(string_pool_->InternString(name)); +} + +FunctionType TypePool::MakeFunctionType(const Type& result, + absl::Span args) { + return function_type_pool_.InternFunctionType(result, args); +} + +OpaqueType TypePool::MakeOpaqueType(absl::string_view name, + absl::Span params) { + return opaque_type_pool_.InternOpaqueType(string_pool_->InternString(name), + params); +} + +TypeParamType TypePool::MakeTypeParamType(absl::string_view name) { + return TypeParamType(string_pool_->InternString(name)); +} + +TypeType TypePool::MakeTypeType(const Type& type) { + return type_type_pool_.InternTypeType(type); +} + +} // namespace cel diff --git a/common/type_pool.h b/common/type_pool.h new file mode 100644 index 000000000..e20915988 --- /dev/null +++ b/common/type_pool.h @@ -0,0 +1,115 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/arena_string.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/types/function_type_pool.h" +#include "common/types/list_type_pool.h" +#include "common/types/map_type_pool.h" +#include "common/types/opaque_type_pool.h" +#include "common/types/type_type_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class TypePool; + +absl::Nonnull> NewTypePool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull string_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class TypePool final { + public: + TypePool(const TypePool&) = delete; + TypePool(TypePool&&) = delete; + TypePool& operator=(const TypePool&) = delete; + TypePool& operator=(TypePool&&) = delete; + + ListType MakeListType(const Type& element); + + MapType MakeMapType(const Type& key, const Type& value); + + StructType MakeStructType(absl::string_view name); + + StructType MakeStructType(ArenaString) = delete; + + FunctionType MakeFunctionType(const Type& result, + absl::Span args); + + OpaqueType MakeOpaqueType(absl::string_view name, + absl::Span params); + + OpaqueType MakeOpaqueType(ArenaString, absl::Span) = delete; + + OptionalType MakeOptionalType(const Type& param) { + return static_cast( + MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(¶m, 1))); + } + + TypeParamType MakeTypeParamType(absl::string_view name); + + TypeParamType MakeTypeParamType(ArenaString) = delete; + + TypeType MakeTypeType(const Type& type); + + private: + friend absl::Nonnull> NewTypePool( + absl::Nonnull, absl::Nonnull, + absl::Nullable); + + TypePool(absl::Nonnull arena, + absl::Nonnull string_pool, + absl::Nullable descriptor_pool) + : string_pool_(string_pool), + descriptor_pool_(descriptor_pool), + function_type_pool_(arena), + list_type_pool_(arena), + map_type_pool_(arena), + opaque_type_pool_(arena), + type_type_pool_(arena) {} + + absl::Nonnull const string_pool_; + absl::Nullable const descriptor_pool_; + common_internal::FunctionTypePool function_type_pool_; + common_internal::ListTypePool list_type_pool_; + common_internal::MapTypePool map_type_pool_; + common_internal::OpaqueTypePool opaque_type_pool_; + common_internal::TypeTypePool type_type_pool_; +}; + +inline absl::Nonnull> NewTypePool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull string_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return std::unique_ptr( + new TypePool(arena, string_pool, descriptor_pool)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ diff --git a/common/type_pool_test.cc b/common/type_pool_test.cc new file mode 100644 index 000000000..e14ebe0ef --- /dev/null +++ b/common/type_pool_test.cc @@ -0,0 +1,113 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_pool.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::internal::GetTestingDescriptorPool; +using testing::_; +using testing::Test; + +class TypePoolTest : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypePoolTest, MakeStructType) { + EXPECT_EQ(type_pool()->MakeStructType("foo.Bar"), + common_internal::MakeBasicStructType("foo.Bar")); + EXPECT_TRUE( + type_pool() + ->MakeStructType("google.api.expr.test.v1.proto3.TestAllTypes") + .IsMessage()); + EXPECT_DEBUG_DEATH(static_cast(type_pool()->MakeStructType( + "google.protobuf.BoolValue")), + _); +} + +TEST_F(TypePoolTest, MakeFunctionType) { + EXPECT_EQ(type_pool()->MakeFunctionType(BoolType(), {IntType(), IntType()}), + FunctionType(arena(), BoolType(), {IntType(), IntType()})); +} + +TEST_F(TypePoolTest, MakeListType) { + EXPECT_EQ(type_pool()->MakeListType(DynType()), ListType()); + EXPECT_EQ(type_pool()->MakeListType(DynType()), JsonListType()); + EXPECT_EQ(type_pool()->MakeListType(StringType()), + ListType(arena(), StringType())); +} + +TEST_F(TypePoolTest, MakeMapType) { + EXPECT_EQ(type_pool()->MakeMapType(DynType(), DynType()), MapType()); + EXPECT_EQ(type_pool()->MakeMapType(StringType(), DynType()), JsonMapType()); + EXPECT_EQ(type_pool()->MakeMapType(StringType(), StringType()), + MapType(arena(), StringType(), StringType())); +} + +TEST_F(TypePoolTest, MakeOpaqueType) { + EXPECT_EQ(type_pool()->MakeOpaqueType("custom_type", {DynType(), DynType()}), + OpaqueType(arena(), "custom_type", {DynType(), DynType()})); +} + +TEST_F(TypePoolTest, MakeOptionalType) { + EXPECT_EQ(type_pool()->MakeOptionalType(DynType()), OptionalType()); + EXPECT_EQ(type_pool()->MakeOptionalType(StringType()), + OptionalType(arena(), StringType())); +} + +TEST_F(TypePoolTest, MakeTypeParamType) { + EXPECT_EQ(type_pool()->MakeTypeParamType("T"), TypeParamType("T")); +} + +TEST_F(TypePoolTest, MakeTypeType) { + EXPECT_EQ(type_pool()->MakeTypeType(BoolType()), + TypeType(arena(), BoolType())); +} + +} // namespace +} // namespace cel diff --git a/common/type_proto.cc b/common/type_proto.cc new file mode 100644 index 000000000..c6afe869c --- /dev/null +++ b/common/type_proto.cc @@ -0,0 +1,353 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_pool.h" + +namespace cel { + +namespace { + +using TypeProto = ::cel::expr::Type; +using ListTypeProto = typename TypeProto::ListType; +using MapTypeProto = typename TypeProto::MapType; +using FunctionTypeProto = typename TypeProto::FunctionType; +using OpaqueTypeProto = typename TypeProto::AbstractType; +using PrimitiveTypeProto = typename TypeProto::PrimitiveType; +using WellKnownTypeProto = typename TypeProto::WellKnownType; + +struct TypeFromProtoConverter final { + explicit TypeFromProtoConverter(absl::Nonnull type_pool) + : type_pool(type_pool) {} + + absl::optional FromType(const TypeProto& proto) { + switch (proto.type_kind_case()) { + case TypeProto::TYPE_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case TypeProto::kDyn: + return DynType(); + case TypeProto::kNull: + return NullType(); + case TypeProto::kPrimitive: + switch (proto.primitive()) { + case TypeProto::BOOL: + return BoolType(); + case TypeProto::INT64: + return IntType(); + case TypeProto::UINT64: + return UintType(); + case TypeProto::DOUBLE: + return DoubleType(); + case TypeProto::STRING: + return StringType(); + case TypeProto::BYTES: + return BytesType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected primitive type kind: ", proto.primitive())); + return absl::nullopt; + } + case TypeProto::kWrapper: + switch (proto.wrapper()) { + case TypeProto::BOOL: + return BoolWrapperType(); + case TypeProto::INT64: + return IntWrapperType(); + case TypeProto::UINT64: + return UintWrapperType(); + case TypeProto::DOUBLE: + return DoubleWrapperType(); + case TypeProto::STRING: + return StringWrapperType(); + case TypeProto::BYTES: + return BytesWrapperType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected wrapper type kind: ", proto.wrapper())); + return absl::nullopt; + } + case TypeProto::kWellKnown: + switch (proto.well_known()) { + case TypeProto::ANY: + return AnyType(); + case TypeProto::DURATION: + return DurationType(); + case TypeProto::TIMESTAMP: + return TimestampType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected well known type kind: ", proto.well_known())); + return absl::nullopt; + } + case TypeProto::kListType: { + auto elem = FromType(proto.list_type().elem_type()); + if (ABSL_PREDICT_FALSE(!elem.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeListType(*elem); + } + case TypeProto::kMapType: { + auto key = FromType(proto.map_type().key_type()); + if (ABSL_PREDICT_FALSE(!key.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + auto value = FromType(proto.map_type().value_type()); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeMapType(*key, *value); + } + case TypeProto::kFunction: { + auto result = FromType(proto.function().result_type()); + if (ABSL_PREDICT_FALSE(!result.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + absl::InlinedVector args; + args.reserve(static_cast(proto.function().arg_types().size())); + for (const auto& arg_proto : proto.function().arg_types()) { + auto arg = FromType(arg_proto); + if (ABSL_PREDICT_FALSE(!arg.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + args.push_back(*arg); + } + return type_pool->MakeFunctionType(*result, args); + } + case TypeProto::kMessageType: + if (ABSL_PREDICT_FALSE(proto.message_type().empty())) { + status = + absl::InvalidArgumentError("unexpected empty message type name"); + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(IsWellKnownMessageType(proto.message_type()))) { + status = absl::InvalidArgumentError( + absl::StrCat("well known type masquerading as message type: ", + proto.message_type())); + return absl::nullopt; + } + return type_pool->MakeStructType(proto.message_type()); + case TypeProto::kTypeParam: + if (ABSL_PREDICT_FALSE(proto.type_param().empty())) { + status = + absl::InvalidArgumentError("unexpected empty type param name"); + return absl::nullopt; + } + return type_pool->MakeTypeParamType(proto.type_param()); + case TypeProto::kType: { + auto type = FromType(proto.type()); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeTypeType(*type); + } + case TypeProto::kError: + return ErrorType(); + case TypeProto::kAbstractType: { + if (proto.abstract_type().name().empty()) { + status = + absl::InvalidArgumentError("unexpected empty opaque type name"); + return absl::nullopt; + } + absl::InlinedVector params; + params.reserve(static_cast( + proto.abstract_type().parameter_types().size())); + for (const auto& param_proto : + proto.abstract_type().parameter_types()) { + auto param = FromType(param_proto); + if (ABSL_PREDICT_FALSE(!param.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + params.push_back(*param); + } + return type_pool->MakeOpaqueType(proto.abstract_type().name(), params); + } + default: + status = absl::DataLossError(absl::StrCat("unexpected type kind case: ", + proto.type_kind_case())); + return absl::nullopt; + } + } + + absl::Nonnull const type_pool; + absl::Status status; +}; + +} // namespace + +absl::StatusOr TypeFromProto(absl::Nonnull type_pool, + const TypeProto& proto) { + TypeFromProtoConverter converter(type_pool); + auto type = converter.FromType(proto); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return *type; +} + +namespace { + +struct TypeToProtoConverter final { + bool FromType(const Type& type, absl::Nonnull proto) { + switch (type.kind()) { + case TypeKind::kDyn: + proto->mutable_dyn(); + return true; + case TypeKind::kNull: + proto->set_null(google::protobuf::NULL_VALUE); + return true; + case TypeKind::kBool: + proto->set_primitive(TypeProto::BOOL); + return true; + case TypeKind::kInt: + proto->set_primitive(TypeProto::INT64); + return true; + case TypeKind::kUint: + proto->set_primitive(TypeProto::UINT64); + return true; + case TypeKind::kDouble: + proto->set_primitive(TypeProto::DOUBLE); + return true; + case TypeKind::kBytes: + proto->set_primitive(TypeProto::BYTES); + return true; + case TypeKind::kString: + proto->set_primitive(TypeProto::STRING); + return true; + case TypeKind::kBoolWrapper: + proto->set_wrapper(TypeProto::BOOL); + return true; + case TypeKind::kIntWrapper: + proto->set_wrapper(TypeProto::INT64); + return true; + case TypeKind::kUintWrapper: + proto->set_wrapper(TypeProto::UINT64); + return true; + case TypeKind::kDoubleWrapper: + proto->set_wrapper(TypeProto::DOUBLE); + return true; + case TypeKind::kBytesWrapper: + proto->set_wrapper(TypeProto::BYTES); + return true; + case TypeKind::kStringWrapper: + proto->set_wrapper(TypeProto::STRING); + return true; + case TypeKind::kAny: + proto->set_well_known(TypeProto::ANY); + return true; + case TypeKind::kDuration: + proto->set_well_known(TypeProto::DURATION); + return true; + case TypeKind::kTimestamp: + proto->set_well_known(TypeProto::TIMESTAMP); + return true; + case TypeKind::kList: + return FromType(static_cast(type).GetElement(), + proto->mutable_list_type()->mutable_elem_type()); + case TypeKind::kMap: + return FromType(static_cast(type).GetKey(), + proto->mutable_map_type()->mutable_key_type()) && + FromType(static_cast(type).GetValue(), + proto->mutable_map_type()->mutable_value_type()); + case TypeKind::kStruct: + proto->set_message_type(static_cast(type).name()); + return true; + case TypeKind::kOpaque: { + auto opaque_type = static_cast(type); + auto* opaque_type_proto = proto->mutable_abstract_type(); + opaque_type_proto->set_name(opaque_type.name()); + auto opaque_type_params = opaque_type.GetParameters(); + opaque_type_proto->mutable_parameter_types()->Reserve( + static_cast(opaque_type_params.size())); + for (const auto& param : opaque_type_params) { + if (ABSL_PREDICT_FALSE( + !FromType(param, opaque_type_proto->add_parameter_types()))) { + return false; + } + } + return true; + } + case TypeKind::kTypeParam: + proto->set_type_param(static_cast(type).name()); + return true; + case TypeKind::kType: + return FromType(static_cast(type).GetType(), + proto->mutable_type()); + case TypeKind::kFunction: { + auto function_type = static_cast(type); + auto* function_type_proto = proto->mutable_function(); + if (ABSL_PREDICT_FALSE( + !FromType(function_type.result(), + function_type_proto->mutable_result_type()))) { + return false; + } + auto function_type_args = function_type.args(); + function_type_proto->mutable_arg_types()->Reserve( + static_cast(function_type_args.size())); + for (const auto& arg : function_type_args) { + if (ABSL_PREDICT_FALSE( + !FromType(arg, function_type_proto->add_arg_types()))) { + return false; + } + } + return true; + } + case TypeKind::kError: + proto->mutable_error(); + return true; + default: + status = absl::DataLossError( + absl::StrCat("unexpected type kind: ", type.kind())); + return false; + } + } + + absl::Status status; +}; + +} // namespace + +absl::Status TypeToProto(const Type& type, absl::Nonnull proto) { + TypeToProtoConverter converter; + if (ABSL_PREDICT_FALSE(!converter.FromType(type, proto))) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/type_proto.h b/common/type_proto.h new file mode 100644 index 000000000..8b32e64a2 --- /dev/null +++ b/common/type_proto.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "common/type_pool.h" + +namespace cel { + +// TypeFromProto converts `cel::expr::Type` to `cel::Type`. +absl::StatusOr TypeFromProto(absl::Nonnull type_pool, + const cel::expr::Type& proto); + +// TypeToProto converts `cel::Type` to `cel::expr::Type`. +absl::Status TypeToProto(const Type& type, + absl::Nonnull proto); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ diff --git a/common/type_proto_test.cc b/common/type_proto_test.cc new file mode 100644 index 000000000..84003253e --- /dev/null +++ b/common/type_proto_test.cc @@ -0,0 +1,387 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/type_pool.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::test::EqualsProto; +using testing::Eq; +using testing::Test; + +using TypeProto = ::cel::expr::Type; + +class TypeProtoTest : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypeProtoTest, Dyn) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(dyn: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DynType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Null) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(null: NULL_VALUE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(NullType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Bool) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Int) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Uint) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Double) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, String) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Bytes) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BoolWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, IntWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, UintWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, DoubleWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, StringWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BytesWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Any) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: ANY)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(AnyType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Duration) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: DURATION)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DurationType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Timestamp) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(well_known: TIMESTAMP)pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TimestampType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, List) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(list_type: { elem_type: { primitive: BOOL } })pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ListType(arena(), BoolType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Map) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(map_type: { + key_type: { primitive: INT64 } + value_type: { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(MapType(arena(), IntType(), StringType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Function) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(function: { + result_type: { primitive: INT64 } + arg_types { primitive: STRING } + arg_types { primitive: INT64 } + arg_types { primitive: UINT64 } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(FunctionType(arena(), IntType(), + {StringType(), IntType(), UintType()}))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Struct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(message_type: "google.protobuf.Empty")pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT( + got, Eq(common_internal::MakeBasicStructType("google.protobuf.Empty"))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadStruct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(message_type: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoTest, TypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "T")pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeParamType("T"))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadTypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoTest, Type) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type: { dyn: {} })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeType(arena(), DynType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Error) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(error: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ErrorType())); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Opaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { + name: "optional_type" + parameter_types { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(OptionalType(arena(), StringType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProto(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadOpaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { name: "" })pb", &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/type_proto_v1alpha1.cc b/common/type_proto_v1alpha1.cc new file mode 100644 index 000000000..39bec96e1 --- /dev/null +++ b/common/type_proto_v1alpha1.cc @@ -0,0 +1,354 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto_v1alpha1.h" + +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_pool.h" + +namespace cel { + +namespace { + +using TypeProto = ::google::api::expr::v1alpha1::Type; +using ListTypeProto = typename TypeProto::ListType; +using MapTypeProto = typename TypeProto::MapType; +using FunctionTypeProto = typename TypeProto::FunctionType; +using OpaqueTypeProto = typename TypeProto::AbstractType; +using PrimitiveTypeProto = typename TypeProto::PrimitiveType; +using WellKnownTypeProto = typename TypeProto::WellKnownType; + +struct TypeFromProtoConverter final { + explicit TypeFromProtoConverter(absl::Nonnull type_pool) + : type_pool(type_pool) {} + + absl::optional FromType(const TypeProto& proto) { + switch (proto.type_kind_case()) { + case TypeProto::TYPE_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case TypeProto::kDyn: + return DynType(); + case TypeProto::kNull: + return NullType(); + case TypeProto::kPrimitive: + switch (proto.primitive()) { + case TypeProto::BOOL: + return BoolType(); + case TypeProto::INT64: + return IntType(); + case TypeProto::UINT64: + return UintType(); + case TypeProto::DOUBLE: + return DoubleType(); + case TypeProto::STRING: + return StringType(); + case TypeProto::BYTES: + return BytesType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected primitive type kind: ", proto.primitive())); + return absl::nullopt; + } + case TypeProto::kWrapper: + switch (proto.wrapper()) { + case TypeProto::BOOL: + return BoolWrapperType(); + case TypeProto::INT64: + return IntWrapperType(); + case TypeProto::UINT64: + return UintWrapperType(); + case TypeProto::DOUBLE: + return DoubleWrapperType(); + case TypeProto::STRING: + return StringWrapperType(); + case TypeProto::BYTES: + return BytesWrapperType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected wrapper type kind: ", proto.wrapper())); + return absl::nullopt; + } + case TypeProto::kWellKnown: + switch (proto.well_known()) { + case TypeProto::ANY: + return AnyType(); + case TypeProto::DURATION: + return DurationType(); + case TypeProto::TIMESTAMP: + return TimestampType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected well known type kind: ", proto.well_known())); + return absl::nullopt; + } + case TypeProto::kListType: { + auto elem = FromType(proto.list_type().elem_type()); + if (ABSL_PREDICT_FALSE(!elem.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeListType(*elem); + } + case TypeProto::kMapType: { + auto key = FromType(proto.map_type().key_type()); + if (ABSL_PREDICT_FALSE(!key.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + auto value = FromType(proto.map_type().value_type()); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeMapType(*key, *value); + } + case TypeProto::kFunction: { + auto result = FromType(proto.function().result_type()); + if (ABSL_PREDICT_FALSE(!result.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + absl::InlinedVector args; + args.reserve(static_cast(proto.function().arg_types().size())); + for (const auto& arg_proto : proto.function().arg_types()) { + auto arg = FromType(arg_proto); + if (ABSL_PREDICT_FALSE(!arg.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + args.push_back(*arg); + } + return type_pool->MakeFunctionType(*result, args); + } + case TypeProto::kMessageType: + if (ABSL_PREDICT_FALSE(proto.message_type().empty())) { + status = + absl::InvalidArgumentError("unexpected empty message type name"); + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(IsWellKnownMessageType(proto.message_type()))) { + status = absl::InvalidArgumentError( + absl::StrCat("well known type masquerading as message type: ", + proto.message_type())); + return absl::nullopt; + } + return type_pool->MakeStructType(proto.message_type()); + case TypeProto::kTypeParam: + if (ABSL_PREDICT_FALSE(proto.type_param().empty())) { + status = + absl::InvalidArgumentError("unexpected empty type param name"); + return absl::nullopt; + } + return type_pool->MakeTypeParamType(proto.type_param()); + case TypeProto::kType: { + auto type = FromType(proto.type()); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeTypeType(*type); + } + case TypeProto::kError: + return ErrorType(); + case TypeProto::kAbstractType: { + if (proto.abstract_type().name().empty()) { + status = + absl::InvalidArgumentError("unexpected empty opaque type name"); + return absl::nullopt; + } + absl::InlinedVector params; + params.reserve(static_cast( + proto.abstract_type().parameter_types().size())); + for (const auto& param_proto : + proto.abstract_type().parameter_types()) { + auto param = FromType(param_proto); + if (ABSL_PREDICT_FALSE(!param.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + params.push_back(*param); + } + return type_pool->MakeOpaqueType(proto.abstract_type().name(), params); + } + default: + status = absl::DataLossError(absl::StrCat("unexpected type kind case: ", + proto.type_kind_case())); + return absl::nullopt; + } + } + + absl::Nonnull const type_pool; + absl::Status status; +}; + +} // namespace + +absl::StatusOr TypeFromProtoV1Alpha1(absl::Nonnull type_pool, + const TypeProto& proto) { + TypeFromProtoConverter converter(type_pool); + auto type = converter.FromType(proto); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return *type; +} + +namespace { + +struct TypeToProtoConverter final { + bool FromType(const Type& type, absl::Nonnull proto) { + switch (type.kind()) { + case TypeKind::kDyn: + proto->mutable_dyn(); + return true; + case TypeKind::kNull: + proto->set_null(google::protobuf::NULL_VALUE); + return true; + case TypeKind::kBool: + proto->set_primitive(TypeProto::BOOL); + return true; + case TypeKind::kInt: + proto->set_primitive(TypeProto::INT64); + return true; + case TypeKind::kUint: + proto->set_primitive(TypeProto::UINT64); + return true; + case TypeKind::kDouble: + proto->set_primitive(TypeProto::DOUBLE); + return true; + case TypeKind::kBytes: + proto->set_primitive(TypeProto::BYTES); + return true; + case TypeKind::kString: + proto->set_primitive(TypeProto::STRING); + return true; + case TypeKind::kBoolWrapper: + proto->set_wrapper(TypeProto::BOOL); + return true; + case TypeKind::kIntWrapper: + proto->set_wrapper(TypeProto::INT64); + return true; + case TypeKind::kUintWrapper: + proto->set_wrapper(TypeProto::UINT64); + return true; + case TypeKind::kDoubleWrapper: + proto->set_wrapper(TypeProto::DOUBLE); + return true; + case TypeKind::kBytesWrapper: + proto->set_wrapper(TypeProto::BYTES); + return true; + case TypeKind::kStringWrapper: + proto->set_wrapper(TypeProto::STRING); + return true; + case TypeKind::kAny: + proto->set_well_known(TypeProto::ANY); + return true; + case TypeKind::kDuration: + proto->set_well_known(TypeProto::DURATION); + return true; + case TypeKind::kTimestamp: + proto->set_well_known(TypeProto::TIMESTAMP); + return true; + case TypeKind::kList: + return FromType(static_cast(type).GetElement(), + proto->mutable_list_type()->mutable_elem_type()); + case TypeKind::kMap: + return FromType(static_cast(type).GetKey(), + proto->mutable_map_type()->mutable_key_type()) && + FromType(static_cast(type).GetValue(), + proto->mutable_map_type()->mutable_value_type()); + case TypeKind::kStruct: + proto->set_message_type(static_cast(type).name()); + return true; + case TypeKind::kOpaque: { + auto opaque_type = static_cast(type); + auto* opaque_type_proto = proto->mutable_abstract_type(); + opaque_type_proto->set_name(opaque_type.name()); + auto opaque_type_params = opaque_type.GetParameters(); + opaque_type_proto->mutable_parameter_types()->Reserve( + static_cast(opaque_type_params.size())); + for (const auto& param : opaque_type_params) { + if (ABSL_PREDICT_FALSE( + !FromType(param, opaque_type_proto->add_parameter_types()))) { + return false; + } + } + return true; + } + case TypeKind::kTypeParam: + proto->set_type_param(static_cast(type).name()); + return true; + case TypeKind::kType: + return FromType(static_cast(type).GetType(), + proto->mutable_type()); + case TypeKind::kFunction: { + auto function_type = static_cast(type); + auto* function_type_proto = proto->mutable_function(); + if (ABSL_PREDICT_FALSE( + !FromType(function_type.result(), + function_type_proto->mutable_result_type()))) { + return false; + } + auto function_type_args = function_type.args(); + function_type_proto->mutable_arg_types()->Reserve( + static_cast(function_type_args.size())); + for (const auto& arg : function_type_args) { + if (ABSL_PREDICT_FALSE( + !FromType(arg, function_type_proto->add_arg_types()))) { + return false; + } + } + return true; + } + case TypeKind::kError: + proto->mutable_error(); + return true; + default: + status = absl::DataLossError( + absl::StrCat("unexpected type kind: ", type.kind())); + return false; + } + } + + absl::Status status; +}; + +} // namespace + +absl::Status TypeToProtoV1Alpha1(const Type& type, + absl::Nonnull proto) { + TypeToProtoConverter converter; + if (ABSL_PREDICT_FALSE(!converter.FromType(type, proto))) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/type_proto_v1alpha1.h b/common/type_proto_v1alpha1.h new file mode 100644 index 000000000..b9cc92600 --- /dev/null +++ b/common/type_proto_v1alpha1.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "common/type_pool.h" + +namespace cel { + +// TypeFromProtoV1Alpha1 converts `google::api::expr::v1alpha1::Type` to +// `cel::Type`. +absl::StatusOr TypeFromProtoV1Alpha1( + absl::Nonnull type_pool, + const google::api::expr::v1alpha1::Type& proto); + +// TypeToProtoV1Alpha1 converts `cel::Type` to +// `google::api::expr::v1alpha1::Type`. +absl::Status TypeToProtoV1Alpha1( + const Type& type, absl::Nonnull proto); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ diff --git a/common/type_proto_v1alpha1_test.cc b/common/type_proto_v1alpha1_test.cc new file mode 100644 index 000000000..57ceb71f4 --- /dev/null +++ b/common/type_proto_v1alpha1_test.cc @@ -0,0 +1,412 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto_v1alpha1.h" + +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/type_pool.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::test::EqualsProto; +using testing::Eq; +using testing::Test; + +using TypeProto = ::google::api::expr::v1alpha1::Type; + +class TypeProtoV1Alpha1Test : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypeProtoV1Alpha1Test, Dyn) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(dyn: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DynType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Null) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(null: NULL_VALUE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(NullType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Bool) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Int) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Uint) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Double) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, String) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Bytes) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BoolWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, IntWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, UintWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, DoubleWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, StringWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BytesWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesWrapperType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Any) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: ANY)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(AnyType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Duration) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: DURATION)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DurationType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Timestamp) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(well_known: TIMESTAMP)pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TimestampType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, List) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(list_type: { elem_type: { primitive: BOOL } })pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ListType(arena(), BoolType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Map) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(map_type: { + key_type: { primitive: INT64 } + value_type: { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(MapType(arena(), IntType(), StringType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Function) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(function: { + result_type: { primitive: INT64 } + arg_types { primitive: STRING } + arg_types { primitive: INT64 } + arg_types { primitive: UINT64 } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(FunctionType(arena(), IntType(), + {StringType(), IntType(), UintType()}))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Struct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(message_type: "google.protobuf.Empty")pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT( + got, Eq(common_internal::MakeBasicStructType("google.protobuf.Empty"))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadStruct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(message_type: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoV1Alpha1Test, TypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "T")pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeParamType("T"))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadTypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoV1Alpha1Test, Type) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type: { dyn: {} })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeType(arena(), DynType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Error) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(error: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ErrorType())); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Opaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { + name: "optional_type" + parameter_types { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(OptionalType(arena(), StringType()))); + TypeProto got_proto; + EXPECT_OK(TypeToProtoV1Alpha1(got, &got_proto)); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadOpaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { name: "" })pb", &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/types/map_type_pool.h b/common/types/map_type_pool.h index d86ddb2e9..29b21f154 100644 --- a/common/types/map_type_pool.h +++ b/common/types/map_type_pool.h @@ -18,7 +18,6 @@ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ #include -#include #include #include "absl/base/nullability.h" @@ -41,15 +40,14 @@ class MapTypePool final { MapType InternMapType(const Type& key, const Type& value); private: - using MapTypeTuple = std::tuple, - std::reference_wrapper>; + using MapTypeTuple = std::tuple; static MapTypeTuple AsTuple(const MapType& map_type) { - return AsTuple(map_type.key(), map_type.value()); + return AsTuple(map_type.GetKey(), map_type.GetValue()); } static MapTypeTuple AsTuple(const Type& key, const Type& value) { - return MapTypeTuple{std::cref(key), std::cref(value)}; + return MapTypeTuple{key, value}; } struct Hasher { diff --git a/common/types/opaque_type_pool.h b/common/types/opaque_type_pool.h index 60b2b3c39..fe079febd 100644 --- a/common/types/opaque_type_pool.h +++ b/common/types/opaque_type_pool.h @@ -45,7 +45,7 @@ class OpaqueTypePool final { absl::Span parameters); private: - using OpaqueTypeTuple = std::tuple>; + using OpaqueTypeTuple = std::tuple; static OpaqueTypeTuple AsTuple(const OpaqueType& opaque_type) { return AsTuple(opaque_type.name(), opaque_type.GetParameters()); diff --git a/common/types/type_pool.cc b/common/types/type_pool.cc deleted file mode 100644 index fdbae2418..000000000 --- a/common/types/type_pool.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "common/types/type_pool.h" - -#include "absl/base/optimization.h" -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "common/type.h" - -namespace cel::common_internal { - -StructType TypePool::MakeStructType(absl::string_view name) { - ABSL_DCHECK(!IsWellKnownMessageType(name)) << name; - if (ABSL_PREDICT_FALSE(name.empty())) { - return StructType(); - } - if (const auto* descriptor = descriptors_->FindMessageTypeByName(name); - descriptor != nullptr) { - return MessageType(descriptor); - } - return MakeBasicStructType(InternString(name)); -} - -FunctionType TypePool::MakeFunctionType(const Type& result, - absl::Span args) { - absl::MutexLock lock(&functions_mutex_); - return functions_.InternFunctionType(result, args); -} - -ListType TypePool::MakeListType(const Type& element) { - if (element.IsDyn()) { - return ListType(); - } - absl::MutexLock lock(&lists_mutex_); - return lists_.InternListType(element); -} - -MapType TypePool::MakeMapType(const Type& key, const Type& value) { - if (key.IsDyn() && value.IsDyn()) { - return MapType(); - } - if (key.IsString() && value.IsDyn()) { - return JsonMapType(); - } - absl::MutexLock lock(&maps_mutex_); - return maps_.InternMapType(key, value); -} - -OpaqueType TypePool::MakeOpaqueType(absl::string_view name, - absl::Span parameters) { - if (name == OptionalType::kName) { - if (parameters.size() == 1 && parameters.front().IsDyn()) { - return OptionalType(); - } - name = OptionalType::kName; - } else { - name = InternString(name); - } - absl::MutexLock lock(&opaques_mutex_); - return opaques_.InternOpaqueType(name, parameters); -} - -OptionalType TypePool::MakeOptionalType(const Type& parameter) { - return static_cast( - MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(¶meter, 1))); -} - -TypeParamType TypePool::MakeTypeParamType(absl::string_view name) { - return TypeParamType(InternString(name)); -} - -TypeType TypePool::MakeTypeType(const Type& type) { - absl::MutexLock lock(&types_mutex_); - return types_.InternTypeType(type); -} - -absl::string_view TypePool::InternString(absl::string_view string) { - absl::MutexLock lock(&strings_mutex_); - return strings_.InternString(string); -} - -} // namespace cel::common_internal diff --git a/common/types/type_pool.h b/common/types/type_pool.h deleted file mode 100644 index 37f3ff662..000000000 --- a/common/types/type_pool.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ - -#include "absl/base/attributes.h" -#include "absl/base/nullability.h" -#include "absl/base/thread_annotations.h" -#include "absl/log/die_if_null.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "common/type.h" -#include "common/types/function_type_pool.h" -#include "common/types/list_type_pool.h" -#include "common/types/map_type_pool.h" -#include "common/types/opaque_type_pool.h" -#include "common/types/type_type_pool.h" -#include "internal/string_pool.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" - -namespace cel::common_internal { - -// `TypePool` is a thread safe interning factory for complex types. All types -// are allocated using the provided `google::protobuf::Arena`. -class TypePool final { - public: - TypePool(absl::Nonnull descriptors - ABSL_ATTRIBUTE_LIFETIME_BOUND, - absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) - : descriptors_(ABSL_DIE_IF_NULL(descriptors)), // Crash OK - arena_(ABSL_DIE_IF_NULL(arena)), // Crash OK - strings_(arena_), - functions_(arena_), - lists_(arena_), - maps_(arena_), - opaques_(arena_), - types_(arena_) {} - - TypePool(const TypePool&) = delete; - TypePool(TypePool&&) = delete; - TypePool& operator=(const TypePool&) = delete; - TypePool& operator=(TypePool&&) = delete; - - StructType MakeStructType(absl::string_view name); - - FunctionType MakeFunctionType(const Type& result, - absl::Span args); - - ListType MakeListType(const Type& element); - - MapType MakeMapType(const Type& key, const Type& value); - - OpaqueType MakeOpaqueType(absl::string_view name, - absl::Span parameters); - - OptionalType MakeOptionalType(const Type& parameter); - - TypeParamType MakeTypeParamType(absl::string_view name); - - TypeType MakeTypeType(const Type& type); - - private: - absl::string_view InternString(absl::string_view string); - - absl::Nonnull const descriptors_; - absl::Nonnull const arena_; - absl::Mutex strings_mutex_; - internal::StringPool strings_ ABSL_GUARDED_BY(strings_mutex_); - absl::Mutex functions_mutex_; - FunctionTypePool functions_ ABSL_GUARDED_BY(functions_mutex_); - absl::Mutex lists_mutex_; - ListTypePool lists_ ABSL_GUARDED_BY(lists_mutex_); - absl::Mutex maps_mutex_; - MapTypePool maps_ ABSL_GUARDED_BY(maps_mutex_); - absl::Mutex opaques_mutex_; - OpaqueTypePool opaques_ ABSL_GUARDED_BY(opaques_mutex_); - absl::Mutex types_mutex_; - TypeTypePool types_ ABSL_GUARDED_BY(types_mutex_); -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ diff --git a/common/types/type_pool_test.cc b/common/types/type_pool_test.cc deleted file mode 100644 index 6b079b2dd..000000000 --- a/common/types/type_pool_test.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "common/types/type_pool.h" - -#include "common/type.h" -#include "internal/testing.h" -#include "internal/testing_descriptor_pool.h" -#include "google/protobuf/arena.h" - -namespace cel::common_internal { -namespace { - -using ::cel::internal::GetTestingDescriptorPool; -using testing::_; - -TEST(TypePool, MakeStructType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeStructType("foo.Bar"), - MakeBasicStructType("foo.Bar")); - EXPECT_TRUE( - type_pool.MakeStructType("google.api.expr.test.v1.proto3.TestAllTypes") - .IsMessage()); - EXPECT_DEBUG_DEATH( - static_cast(type_pool.MakeStructType("google.protobuf.BoolValue")), - _); -} - -TEST(TypePool, MakeFunctionType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeFunctionType(BoolType(), {IntType(), IntType()}), - FunctionType(&arena, BoolType(), {IntType(), IntType()})); -} - -TEST(TypePool, MakeListType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeListType(DynType()), ListType()); - EXPECT_EQ(type_pool.MakeListType(DynType()), JsonListType()); - EXPECT_EQ(type_pool.MakeListType(StringType()), - ListType(&arena, StringType())); -} - -TEST(TypePool, MakeMapType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeMapType(DynType(), DynType()), MapType()); - EXPECT_EQ(type_pool.MakeMapType(StringType(), DynType()), JsonMapType()); - EXPECT_EQ(type_pool.MakeMapType(StringType(), StringType()), - MapType(&arena, StringType(), StringType())); -} - -TEST(TypePool, MakeOpaqueType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeOpaqueType("custom_type", {DynType(), DynType()}), - OpaqueType(&arena, "custom_type", {DynType(), DynType()})); -} - -TEST(TypePool, MakeOptionalType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeOptionalType(DynType()), OptionalType()); - EXPECT_EQ(type_pool.MakeOptionalType(StringType()), - OptionalType(&arena, StringType())); -} - -TEST(TypePool, MakeTypeParamType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeTypeParamType("T"), TypeParamType("T")); -} - -TEST(TypePool, MakeTypeType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeTypeType(BoolType()), TypeType(&arena, BoolType())); -} - -} // namespace -} // namespace cel::common_internal diff --git a/internal/BUILD b/internal/BUILD index 7c7c6b00b..d174b11cb 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -581,6 +581,7 @@ cel_proto_transitive_descriptor_set( name = "testing_descriptor_set", testonly = True, deps = [ + "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_proto", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_proto",