diff --git a/api/BUILD b/api/BUILD index 2caeb712b65e..1cf18ef314de 100644 --- a/api/BUILD +++ b/api/BUILD @@ -219,6 +219,7 @@ proto_library( "//envoy/extensions/internal_redirect/safe_cross_scheme/v3:pkg", "//envoy/extensions/key_value/file_based/v3:pkg", "//envoy/extensions/matching/common_inputs/environment_variable/v3:pkg", + "//envoy/extensions/matching/common_inputs/network/v3:pkg", "//envoy/extensions/matching/input_matchers/consistent_hashing/v3:pkg", "//envoy/extensions/matching/input_matchers/ip/v3:pkg", "//envoy/extensions/network/dns_resolver/apple/v3:pkg", diff --git a/api/envoy/extensions/matching/common_inputs/network/v3/BUILD b/api/envoy/extensions/matching/common_inputs/network/v3/BUILD new file mode 100644 index 000000000000..ee92fb652582 --- /dev/null +++ b/api/envoy/extensions/matching/common_inputs/network/v3/BUILD @@ -0,0 +1,9 @@ +# DO NOT EDIT. This file is generated by tools/proto_format/proto_sync.py. + +load("@envoy_api//bazel:api_build_system.bzl", "api_proto_package") + +licenses(["notice"]) # Apache 2 + +api_proto_package( + deps = ["@com_github_cncf_udpa//udpa/annotations:pkg"], +) diff --git a/api/envoy/extensions/matching/common_inputs/network/v3/network_inputs.proto b/api/envoy/extensions/matching/common_inputs/network/v3/network_inputs.proto new file mode 100644 index 000000000000..8f54d3458702 --- /dev/null +++ b/api/envoy/extensions/matching/common_inputs/network/v3/network_inputs.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; + +package envoy.extensions.matching.common_inputs.network.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.matching.common_inputs.network.v3"; +option java_outer_classname = "NetworkInputsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/matching/common_inputs/network/v3;networkv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Common Network Matching Inputs] + +// Specifies that matching should be performed by the destination IP address. +message DestinationIPInput { +} + +// Specifies that matching should be performed by the destination port. +message DestinationPortInput { +} + +// Specifies that matching should be performed by the source IP address. +message SourceIPInput { +} + +// Specifies that matching should be performed by the source port. +message SourcePortInput { +} + +// Input that matches by the directly connected source IP address (this +// will only be different from the source IP address when using a listener +// filter that overrides the source address, such as the :ref:`Proxy Protocol +// listener filter `). +message DirectSourceIPInput { +} + +// Input that matches by the source IP type. +// Specifies the source IP match type. The values include: +// +// * ``local`` - matches a connection originating from the same host, +message SourceTypeInput { +} + +// Input that matches by the requested server name (e.g. SNI in TLS). +// +// :ref:`TLS Inspector ` provides the requested server name based on SNI, +// when TLS protocol is detected. +message ServerNameInput { +} + +// Input that matches by the transport protocol. +// +// Suggested values include: +// +// * ``raw_buffer`` - default, used when no transport protocol is detected, +// * ``tls`` - set by :ref:`envoy.filters.listener.tls_inspector ` +// when TLS protocol is detected. +message TransportProtocolInput { +} + +// List of quoted and comma-separated requested application protocols. The list consists of a +// single negotiated application protocol once the network stream is established. +// +// Examples: +// +// * ``'h2','http/1.1'`` +// * ``'h2c'``` +// +// Suggested values in the list include: +// +// * ``http/1.1`` - set by :ref:`envoy.filters.listener.tls_inspector +// ` and :ref:`envoy.filters.listener.http_inspector +// `, +// * ``h2`` - set by :ref:`envoy.filters.listener.tls_inspector ` +// * ``h2c`` - set by :ref:`envoy.filters.listener.http_inspector ` +// +// .. attention:: +// +// Currently, :ref:`TLS Inspector ` provides +// application protocol detection based on the requested +// `ALPN `_ values. +// +// However, the use of ALPN is pretty much limited to the HTTP/2 traffic on the Internet, +// and matching on values other than ``h2`` is going to lead to a lot of false negatives, +// unless all connecting clients are known to use ALPN. +message ApplicationProtocolInput { +} diff --git a/api/versioning/BUILD b/api/versioning/BUILD index 339713969759..ec1462ad2d57 100644 --- a/api/versioning/BUILD +++ b/api/versioning/BUILD @@ -159,6 +159,7 @@ proto_library( "//envoy/extensions/load_balancing_policies/round_robin/v3:pkg", "//envoy/extensions/load_balancing_policies/wrr_locality/v3:pkg", "//envoy/extensions/matching/common_inputs/environment_variable/v3:pkg", + "//envoy/extensions/matching/common_inputs/network/v3:pkg", "//envoy/extensions/matching/input_matchers/consistent_hashing/v3:pkg", "//envoy/extensions/matching/input_matchers/ip/v3:pkg", "//envoy/extensions/network/dns_resolver/apple/v3:pkg", diff --git a/docs/root/api-v3/common_messages/common_messages.rst b/docs/root/api-v3/common_messages/common_messages.rst index d14a59db966c..8462681ee1bf 100644 --- a/docs/root/api-v3/common_messages/common_messages.rst +++ b/docs/root/api-v3/common_messages/common_messages.rst @@ -31,3 +31,4 @@ Common messages ../extensions/matching/input_matchers/consistent_hashing/v3/consistent_hashing.proto ../extensions/matching/input_matchers/ip/v3/ip.proto ../extensions/matching/common_inputs/environment_variable/v3/input.proto + ../extensions/matching/common_inputs/network/v3/network_inputs.proto diff --git a/envoy/network/filter.h b/envoy/network/filter.h index 16a5de4c2e3e..a41a76cad45a 100644 --- a/envoy/network/filter.h +++ b/envoy/network/filter.h @@ -525,6 +525,10 @@ class FilterChainFactory { class MatchingData { public: static absl::string_view name() { return "network"; } + + virtual ~MatchingData() = default; + + virtual const ConnectionSocket& socket() const PURE; }; } // namespace Network diff --git a/source/common/network/matching/BUILD b/source/common/network/matching/BUILD new file mode 100644 index 000000000000..dce37cb6a2dc --- /dev/null +++ b/source/common/network/matching/BUILD @@ -0,0 +1,30 @@ +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +licenses(["notice"]) # Apache 2 + +envoy_package() + +envoy_cc_library( + name = "data_impl_lib", + hdrs = ["data_impl.h"], + deps = [ + "//envoy/network:filter_interface", + ], +) + +envoy_cc_library( + name = "inputs_lib", + srcs = ["inputs.cc"], + hdrs = ["inputs.h"], + deps = [ + "//envoy/matcher:matcher_interface", + "//envoy/network:filter_interface", + "//envoy/registry", + "//source/common/network:utility_lib", + "@envoy_api//envoy/extensions/matching/common_inputs/network/v3:pkg_cc_proto", + ], +) diff --git a/source/common/network/matching/data_impl.h b/source/common/network/matching/data_impl.h new file mode 100644 index 000000000000..c8fd65b29f26 --- /dev/null +++ b/source/common/network/matching/data_impl.h @@ -0,0 +1,24 @@ +#pragma once + +#include "envoy/network/filter.h" + +namespace Envoy { +namespace Network { +namespace Matching { + +/** + * Implementation of Network::MatchingData, providing connection-level data to + * the match tree. + */ +class MatchingDataImpl : public MatchingData { +public: + explicit MatchingDataImpl(const ConnectionSocket& socket) : socket_(socket) {} + const ConnectionSocket& socket() const override { return socket_; } + +private: + const ConnectionSocket& socket_; +}; + +} // namespace Matching +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/matching/inputs.cc b/source/common/network/matching/inputs.cc new file mode 100644 index 000000000000..deb64a6b4035 --- /dev/null +++ b/source/common/network/matching/inputs.cc @@ -0,0 +1,105 @@ +#include "source/common/network/matching/inputs.h" + +#include "envoy/registry/registry.h" + +#include "source/common/network/utility.h" + +#include "absl/strings/str_cat.h" + +namespace Envoy { +namespace Network { +namespace Matching { + +Matcher::DataInputGetResult DestinationIPInput::get(const MatchingData& data) const { + const auto& address = data.socket().connectionInfoProvider().localAddress(); + if (address->type() != Network::Address::Type::Ip) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + address->ip()->addressAsString()}; +} + +Matcher::DataInputGetResult DestinationPortInput::get(const MatchingData& data) const { + const auto& address = data.socket().connectionInfoProvider().localAddress(); + if (address->type() != Network::Address::Type::Ip) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + absl::StrCat(address->ip()->port())}; +} + +Matcher::DataInputGetResult SourceIPInput::get(const MatchingData& data) const { + const auto& address = data.socket().connectionInfoProvider().remoteAddress(); + if (address->type() != Network::Address::Type::Ip) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + address->ip()->addressAsString()}; +} + +Matcher::DataInputGetResult SourcePortInput::get(const MatchingData& data) const { + const auto& address = data.socket().connectionInfoProvider().remoteAddress(); + if (address->type() != Network::Address::Type::Ip) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + absl::StrCat(address->ip()->port())}; +} + +Matcher::DataInputGetResult DirectSourceIPInput::get(const MatchingData& data) const { + const auto& address = data.socket().connectionInfoProvider().directRemoteAddress(); + if (address->type() != Network::Address::Type::Ip) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + address->ip()->addressAsString()}; +} + +Matcher::DataInputGetResult SourceTypeInput::get(const MatchingData& data) const { + const bool is_local_connection = Network::Utility::isSameIpOrLoopback(data.socket()); + if (is_local_connection) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, "local"}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; +} + +Matcher::DataInputGetResult ServerNameInput::get(const MatchingData& data) const { + const auto server_name = data.socket().requestedServerName(); + if (!server_name.empty()) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + std::string(server_name)}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; +} + +Matcher::DataInputGetResult TransportProtocolInput::get(const MatchingData& data) const { + const auto transport_protocol = data.socket().detectedTransportProtocol(); + if (!transport_protocol.empty()) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + std::string(transport_protocol)}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; +} + +Matcher::DataInputGetResult ApplicationProtocolInput::get(const MatchingData& data) const { + const auto& protocols = data.socket().requestedApplicationProtocols(); + if (!protocols.empty()) { + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, + absl::StrCat("'", absl::StrJoin(protocols, "','"), "'")}; + } + return {Matcher::DataInputGetResult::DataAvailability::AllDataAvailable, absl::nullopt}; +} + +REGISTER_FACTORY(DestinationIPInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(DestinationPortInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(SourceIPInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(SourcePortInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(DirectSourceIPInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(SourceTypeInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(ServerNameInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(TransportProtocolInputFactory, Matcher::DataInputFactory); +REGISTER_FACTORY(ApplicationProtocolInputFactory, Matcher::DataInputFactory); + +} // namespace Matching +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/matching/inputs.h b/source/common/network/matching/inputs.h new file mode 100644 index 000000000000..20bb872e36ed --- /dev/null +++ b/source/common/network/matching/inputs.h @@ -0,0 +1,147 @@ +#pragma once + +#include "envoy/extensions/matching/common_inputs/network/v3/network_inputs.pb.h" +#include "envoy/extensions/matching/common_inputs/network/v3/network_inputs.pb.validate.h" +#include "envoy/matcher/matcher.h" +#include "envoy/network/filter.h" + +namespace Envoy { +namespace Network { +namespace Matching { + +template +class BaseFactory : public Matcher::DataInputFactory { +protected: + explicit BaseFactory(const std::string& name) : name_(name) {} + +public: + std::string name() const override { return name_; } + + Matcher::DataInputFactoryCb + createDataInputFactoryCb(const Protobuf::Message&, ProtobufMessage::ValidationVisitor&) override { + return []() { return std::make_unique(); }; + }; + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + +private: + const std::string name_; +}; + +class DestinationIPInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class DestinationIPInputFactory + : public BaseFactory< + DestinationIPInput, + envoy::extensions::matching::common_inputs::network::v3::DestinationIPInput> { +public: + DestinationIPInputFactory() : BaseFactory("destination-ip") {} +}; + +class DestinationPortInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class DestinationPortInputFactory + : public BaseFactory< + DestinationPortInput, + envoy::extensions::matching::common_inputs::network::v3::DestinationPortInput> { +public: + DestinationPortInputFactory() : BaseFactory("destination-port") {} +}; + +class SourceIPInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class SourceIPInputFactory + : public BaseFactory { +public: + SourceIPInputFactory() : BaseFactory("source-ip") {} +}; + +class SourcePortInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class SourcePortInputFactory + : public BaseFactory { +public: + SourcePortInputFactory() : BaseFactory("source-port") {} +}; + +class DirectSourceIPInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class DirectSourceIPInputFactory + : public BaseFactory< + DirectSourceIPInput, + envoy::extensions::matching::common_inputs::network::v3::DirectSourceIPInput> { +public: + DirectSourceIPInputFactory() : BaseFactory("direct-source-ip") {} +}; + +class SourceTypeInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class SourceTypeInputFactory + : public BaseFactory { +public: + SourceTypeInputFactory() : BaseFactory("source-type") {} +}; + +class ServerNameInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class ServerNameInputFactory + : public BaseFactory { +public: + ServerNameInputFactory() : BaseFactory("server-name") {} +}; + +class TransportProtocolInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class TransportProtocolInputFactory + : public BaseFactory< + TransportProtocolInput, + envoy::extensions::matching::common_inputs::network::v3::TransportProtocolInput> { +public: + TransportProtocolInputFactory() : BaseFactory("transport-protocol") {} +}; + +class ApplicationProtocolInput : public Matcher::DataInput { +public: + Matcher::DataInputGetResult get(const MatchingData& data) const override; +}; + +class ApplicationProtocolInputFactory + : public BaseFactory< + ApplicationProtocolInput, + envoy::extensions::matching::common_inputs::network::v3::ApplicationProtocolInput> { +public: + ApplicationProtocolInputFactory() : BaseFactory("application-protocol") {} +}; + +} // namespace Matching +} // namespace Network +} // namespace Envoy diff --git a/test/common/network/matching/BUILD b/test/common/network/matching/BUILD new file mode 100644 index 000000000000..e9f3947bdab0 --- /dev/null +++ b/test/common/network/matching/BUILD @@ -0,0 +1,20 @@ +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_test", + "envoy_package", +) + +licenses(["notice"]) # Apache 2 + +envoy_package() + +envoy_cc_test( + name = "inputs_test", + srcs = ["inputs_test.cc"], + deps = [ + "//source/common/network:address_lib", + "//source/common/network/matching:data_impl_lib", + "//source/common/network/matching:inputs_lib", + "//test/mocks/network:network_mocks", + ], +) diff --git a/test/common/network/matching/inputs_test.cc b/test/common/network/matching/inputs_test.cc new file mode 100644 index 000000000000..45ab24cfdbfe --- /dev/null +++ b/test/common/network/matching/inputs_test.cc @@ -0,0 +1,238 @@ +#include "envoy/http/filter.h" + +#include "source/common/network/address_impl.h" +#include "source/common/network/matching/data_impl.h" +#include "source/common/network/matching/inputs.h" + +#include "test/mocks/network/mocks.h" + +namespace Envoy { +namespace Network { +namespace Matching { + +TEST(MatchingData, DestinationIPInput) { + DestinationIPInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setLocalAddress( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "127.0.0.1"); + } + + { + socket.connection_info_provider_->setLocalAddress( + std::make_shared("/pipe/path")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, DestinationPortInput) { + DestinationPortInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setLocalAddress( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "8080"); + } + + { + socket.connection_info_provider_->setLocalAddress( + std::make_shared("/pipe/path")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, SourceIPInput) { + SourceIPInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "127.0.0.1"); + } + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("/pipe/path")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, SourcePortInput) { + SourcePortInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "8080"); + } + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("/pipe/path")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, DirectSourceIPInput) { + DirectSourceIPInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setDirectRemoteAddressForTest( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "127.0.0.1"); + } + + { + socket.connection_info_provider_->setDirectRemoteAddressForTest( + std::make_shared("/pipe/path")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, SourceTypeInput) { + SourceTypeInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("127.0.0.1", 8080)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "local"); + } + + { + socket.connection_info_provider_->setRemoteAddress( + std::make_shared("10.0.0.1")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } +} + +TEST(MatchingData, ServerNameInput) { + ServerNameInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + EXPECT_CALL(socket, requestedServerName).WillOnce(testing::Return("")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } + + { + const auto host = "example.com"; + EXPECT_CALL(socket, requestedServerName).WillOnce(testing::Return(host)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, host); + } +} + +TEST(MatchingData, TransportProtocolInput) { + TransportProtocolInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + EXPECT_CALL(socket, detectedTransportProtocol).WillOnce(testing::Return("")); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } + + { + const auto protocol = "tls"; + EXPECT_CALL(socket, detectedTransportProtocol).WillOnce(testing::Return(protocol)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, protocol); + } +} + +TEST(MatchingData, ApplicationProtocolInput) { + ApplicationProtocolInput input; + MockConnectionSocket socket; + MatchingDataImpl data(socket); + + { + std::vector protocols = {}; + EXPECT_CALL(socket, requestedApplicationProtocols).WillOnce(testing::ReturnRef(protocols)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, absl::nullopt); + } + + { + std::vector protocols = {"h2c"}; + EXPECT_CALL(socket, requestedApplicationProtocols).WillOnce(testing::ReturnRef(protocols)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "'h2c'"); + } + + { + std::vector protocols = {"h2", "http/1.1"}; + EXPECT_CALL(socket, requestedApplicationProtocols).WillOnce(testing::ReturnRef(protocols)); + const auto result = input.get(data); + EXPECT_EQ(result.data_availability_, + Matcher::DataInputGetResult::DataAvailability::AllDataAvailable); + EXPECT_EQ(result.data_, "'h2','http/1.1'"); + } +} + +} // namespace Matching +} // namespace Network +} // namespace Envoy diff --git a/test/extensions/common/matcher/BUILD b/test/extensions/common/matcher/BUILD index e3724745b144..a6a0bf0d0a5a 100644 --- a/test/extensions/common/matcher/BUILD +++ b/test/extensions/common/matcher/BUILD @@ -23,10 +23,14 @@ envoy_cc_test( srcs = ["trie_matcher_test.cc"], deps = [ "//source/common/matcher:matcher_lib", + "//source/common/network:address_lib", + "//source/common/network/matching:data_impl_lib", + "//source/common/network/matching:inputs_lib", "//source/extensions/common/matcher:trie_matcher_lib", "//test/common/matcher:test_utility_lib", "//test/mocks/http:http_mocks", "//test/mocks/matcher:matcher_mocks", + "//test/mocks/network:network_mocks", "//test/mocks/server:factory_context_mocks", "//test/mocks/stream_info:stream_info_mocks", "//test/test_common:registry_lib", diff --git a/test/extensions/common/matcher/trie_matcher_test.cc b/test/extensions/common/matcher/trie_matcher_test.cc index d5c66ed71bb0..ddf41211576f 100644 --- a/test/extensions/common/matcher/trie_matcher_test.cc +++ b/test/extensions/common/matcher/trie_matcher_test.cc @@ -5,11 +5,14 @@ #include "envoy/registry/registry.h" #include "source/common/matcher/matcher.h" +#include "source/common/network/address_impl.h" +#include "source/common/network/matching/data_impl.h" #include "source/common/protobuf/utility.h" #include "source/extensions/common/matcher/trie_matcher.h" #include "test/common/matcher/test_utility.h" #include "test/mocks/matcher/mocks.h" +#include "test/mocks/network/mocks.h" #include "test/mocks/server/factory_context.h" #include "test/test_common/registry.h" #include "test/test_common/utility.h" @@ -441,6 +444,51 @@ TEST_F(TrieMatcherTest, NoData) { } } +TEST(TrieMatcherIntegrationTest, NetworkMatchingData) { + const std::string yaml = R"EOF( +matcher_tree: + input: + name: input + typed_config: + "@type": type.googleapis.com/envoy.extensions.matching.common_inputs.network.v3.DestinationIPInput + custom_match: + name: ip_matcher + typed_config: + "@type": type.googleapis.com/xds.type.matcher.v3.IPMatcher + range_matchers: + - ranges: + - address_prefix: 192.0.0.0 + prefix_len: 2 + on_match: + action: + name: test_action + typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: foo + )EOF"; + xds::type::matcher::v3::Matcher matcher; + MessageUtil::loadFromYaml(yaml, matcher, ProtobufMessage::getStrictValidationVisitor()); + + StringActionFactory action_factory; + Registry::InjectFactory> inject_action(action_factory); + NiceMock factory_context; + MockMatchTreeValidationVisitor validation_visitor; + EXPECT_CALL(validation_visitor, performDataInputValidation(_, _)).Times(testing::AnyNumber()); + absl::string_view context = ""; + MatchTreeFactory matcher_factory( + context, factory_context, validation_visitor); + auto match_tree = matcher_factory.create(matcher); + + Network::MockConnectionSocket socket; + socket.connection_info_provider_->setLocalAddress( + std::make_shared("192.168.0.1", 8080)); + Network::Matching::MatchingDataImpl data(socket); + + const auto result = match_tree()->match(data); + EXPECT_EQ(result.match_state_, MatchState::MatchComplete); + EXPECT_EQ(result.on_match_->action_cb_()->getTyped().string_, "foo"); +} + } // namespace } // namespace Matcher } // namespace Common