diff --git a/configs/envoy_service_to_service.template.json b/configs/envoy_service_to_service.template.json index fcbc4ae0f993..674e77ad52a6 100644 --- a/configs/envoy_service_to_service.template.json +++ b/configs/envoy_service_to_service.template.json @@ -245,8 +245,14 @@ "type": "read", "name": "tcp_proxy", "config": { - "cluster": "mongo_{{ key }}", - "stat_prefix": "mongo_{{ key }}" + "stat_prefix": "mongo_{{ key }}", + "route_config": { + "routes": [ + { + "cluster": "mongo_{{ key }}" + } + ] + } } }] }{% if not loop.last %},{% endif -%} diff --git a/docs/configuration/network_filters/tcp_proxy_filter.rst b/docs/configuration/network_filters/tcp_proxy_filter.rst index 39eb2cdc8e71..58d4f146322f 100644 --- a/docs/configuration/network_filters/tcp_proxy_filter.rst +++ b/docs/configuration/network_filters/tcp_proxy_filter.rst @@ -11,14 +11,13 @@ TCP proxy :ref:`architecture overview `. "type": "read", "name": "tcp_proxy", "config": { - "cluster": "...", - "stat_prefix": "..." + "stat_prefix": "...", + "route_config": "{...}" } } -cluster - *(required, string)* The :ref:`cluster manager ` cluster to connect - to when a new downstream network connection is received. +:ref:`route_config ` + *(required, object)* The route table for the filter. All filter instances must have a route table, even if it is empty. stat_prefix *(required, string)* The prefix to use when emitting :ref:`statistics @@ -39,3 +38,8 @@ statistics are rooted at *tcp..* with the following statistics: downstream_cx_tx_bytes_total, Counter, Total bytes written to the downstream connection. downstream_cx_tx_bytes_buffered, Gauge, Total bytes currently buffered to the downstream connection. + +.. toctree:: + :hidden: + + tcp_proxy_filter_route_config diff --git a/docs/configuration/network_filters/tcp_proxy_filter_route_config.rst b/docs/configuration/network_filters/tcp_proxy_filter_route_config.rst new file mode 100644 index 000000000000..b5d3ce77709b --- /dev/null +++ b/docs/configuration/network_filters/tcp_proxy_filter_route_config.rst @@ -0,0 +1,21 @@ +.. _config_network_filters_tcp_proxy_route_config: + +Route Configuration +=================== + +* TCP proxy :ref:`architecture overview `. +* TCP proxy :ref:`filter `. + +.. code-block:: json + + { + "routes": [] + } + +:ref:`routes ` + *(required, array)* An array of route entries that make up the route table. + +.. toctree:: + :hidden: + + tcp_proxy_filter_route diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 50bdc0b97f6a..78a938712417 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -115,6 +115,25 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual const std::string& remoteAddress() PURE; + /** + * @return The port number used by the remote client. + */ + virtual uint32_t remotePort() PURE; + + /** + * @return The address the remote client is trying to connect to. + * It can be different from the proxy address if the downstream connection + * has been redirected or the proxy is operating in transparent mode. + */ + virtual const std::string destinationAddress() PURE; + + /** + * @return The port number the remote client is trying to connect to. + * It can be different from the port the listener is listening on if the connection + * has been redirected or the proxy is operating in transparent mode. + */ + virtual uint32_t destinationPort() PURE; + /** * Set the buffer stats to update when the connection's read/write buffers change. Note that * for performance reasons these stats are eventually consistent and may not always accurately diff --git a/source/common/filter/tcp_proxy.cc b/source/common/filter/tcp_proxy.cc index 96c7cd1656eb..f8156fecee9f 100644 --- a/source/common/filter/tcp_proxy.cc +++ b/source/common/filter/tcp_proxy.cc @@ -8,19 +8,84 @@ #include "envoy/upstream/upstream.h" #include "common/common/assert.h" +#include "common/common/empty_string.h" #include "common/json/json_loader.h" namespace Filter { +TcpProxyConfig::Route::Route(const Json::Object& config) { + if (config.hasObject("cluster")) { + cluster_name_ = config.getString("cluster"); + } else { + throw EnvoyException(fmt::format("tcp proxy: route without cluster")); + } + + if (config.hasObject("source_ip_list")) { + source_ips_ = Network::IpList(config.getStringArray("source_ip_list")); + } + + if (config.hasObject("source_ports")) { + Network::Utility::parsePortRangeList(config.getString("source_ports"), source_port_ranges_); + } + + if (config.hasObject("destination_ip_list")) { + destination_ips_ = Network::IpList(config.getStringArray("destination_ip_list")); + } + + if (config.hasObject("destination_ports")) { + Network::Utility::parsePortRangeList(config.getString("destination_ports"), + destination_port_ranges_); + } +} + TcpProxyConfig::TcpProxyConfig(const Json::Object& config, Upstream::ClusterManager& cluster_manager, Stats::Store& stats_store) - : cluster_name_(config.getString("cluster")), - stats_(generateStats(config.getString("stat_prefix"), stats_store)) { - if (!cluster_manager.get(cluster_name_)) { - throw EnvoyException(fmt::format("tcp proxy: unknown cluster '{}'", cluster_name_)); + : stats_(generateStats(config.getString("stat_prefix"), stats_store)) { + if (!config.hasObject("route_config")) { + throw EnvoyException(fmt::format("tcp proxy: missing route config")); + } + + for (const Json::ObjectPtr& route_desc : + config.getObject("route_config")->getObjectArray("routes")) { + routes_.emplace_back(Route(*route_desc)); + + if (!cluster_manager.get(route_desc->getString("cluster"))) { + throw EnvoyException(fmt::format("tcp proxy: unknown cluster '{}' in TCP route", + route_desc->getString("cluster"))); + } } } +const std::string& TcpProxyConfig::getClusterForConnection(Network::Connection& connection) { + for (const TcpProxyConfig::Route& route : routes_) { + if (!route.source_port_ranges_.empty() && + !Network::Utility::portInRangeList(connection.remotePort(), route.source_port_ranges_)) { + continue; // no match, try next route + } + + if (!route.source_ips_.empty() && !route.source_ips_.contains(connection.remoteAddress())) { + continue; // no match, try next route + } + + if (!route.destination_port_ranges_.empty() && + !Network::Utility::portInRangeList(connection.destinationPort(), + route.destination_port_ranges_)) { + continue; // no match, try next route + } + + if (!route.destination_ips_.empty() && + !route.destination_ips_.contains(connection.destinationAddress())) { + continue; // no match, try next route + } + + // if we made it past all checks, the route matches + return route.cluster_name_; + } + + // no match, no more routes to try + return EMPTY_STRING; +} + TcpProxy::TcpProxy(TcpProxyConfigPtr config, Upstream::ClusterManager& cluster_manager) : config_(config), cluster_manager_(cluster_manager), downstream_callbacks_(*this), upstream_callbacks_(new UpstreamCallbacks(*this)) {} @@ -56,14 +121,27 @@ void TcpProxy::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callb } Network::FilterStatus TcpProxy::initializeUpstreamConnection() { - Upstream::ClusterInfoPtr cluster = cluster_manager_.get(config_->clusterName()); + const std::string& destination_cluster = + config_->getClusterForConnection(read_callbacks_->connection()); + conn_log_debug("Connection from {}", read_callbacks_->connection(), destination_cluster); + + Upstream::ClusterInfoPtr cluster = cluster_manager_.get(destination_cluster); + if (cluster) { + conn_log_debug("Connection cluster with name {} found", read_callbacks_->connection(), + destination_cluster); + } else { + conn_log_debug("Connection cluster with name {} NOT FOUND", read_callbacks_->connection(), + destination_cluster); + return Network::FilterStatus::StopIteration; + } + if (!cluster->resourceManager(Upstream::ResourcePriority::Default).connections().canCreate()) { cluster->stats().upstream_cx_overflow_.inc(); read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); return Network::FilterStatus::StopIteration; } Upstream::Host::CreateConnectionData conn_info = - cluster_manager_.tcpConnForCluster(config_->clusterName()); + cluster_manager_.tcpConnForCluster(destination_cluster); upstream_connection_ = std::move(conn_info.connection_); read_callbacks_->upstreamHost(conn_info.host_description_); diff --git a/source/common/filter/tcp_proxy.h b/source/common/filter/tcp_proxy.h index 898686076983..9855c6c35959 100644 --- a/source/common/filter/tcp_proxy.h +++ b/source/common/filter/tcp_proxy.h @@ -9,6 +9,7 @@ #include "common/common/logger.h" #include "common/json/json_loader.h" #include "common/network/filter_impl.h" +#include "common/network/utility.h" namespace Filter { @@ -38,13 +39,24 @@ class TcpProxyConfig { TcpProxyConfig(const Json::Object& config, Upstream::ClusterManager& cluster_manager, Stats::Store& stats_store); - const std::string& clusterName() { return cluster_name_; } + const std::string& getClusterForConnection(Network::Connection& connection); + const TcpProxyStats& stats() { return stats_; } private: + struct Route { + Route(const Json::Object& config); + + Network::IpList source_ips_; + Network::PortRangeList source_port_ranges_; + Network::IpList destination_ips_; + Network::PortRangeList destination_port_ranges_; + std::string cluster_name_; + }; + static TcpProxyStats generateStats(const std::string& name, Stats::Store& store); - std::string cluster_name_; + std::list routes_; const TcpProxyStats stats_; }; diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index 1e35f7fa98f0..70254b4d7ae4 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -1,4 +1,5 @@ #include "connection_impl.h" +#include "utility.h" #include "envoy/event/timer.h" #include "envoy/common/exception.h" @@ -33,9 +34,9 @@ void ConnectionImplUtility::updateBufferStats(uint64_t delta, uint64_t new_total std::atomic ConnectionImpl::next_global_id_; ConnectionImpl::ConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, - const std::string& remote_address) - : filter_manager_(*this, *this), remote_address_(remote_address), dispatcher_(dispatcher), - fd_(fd), id_(++next_global_id_) { + const std::string& remote_address, uint32_t remote_port) + : filter_manager_(*this, *this), remote_address_(remote_address), remote_port_(remote_port), + dispatcher_(dispatcher), fd_(fd), id_(++next_global_id_) { // Treat the lack of a valid fd (which in practice only happens if we run out of FDs) as an OOM // condition and just crash. @@ -395,9 +396,35 @@ void ConnectionImpl::updateWriteBufferStats(uint64_t num_written, uint64_t new_s buffer_stats_->write_current_); } +const std::string Network::ConnectionImpl::destinationAddress() { + if (fd_ != -1) { + sockaddr_storage orig_dst_addr; + memset(&orig_dst_addr, 0, sizeof(orig_dst_addr)); + bool success = Utility::getOriginalDst(fd_, &orig_dst_addr); + if (success) { + return Utility::getAddressName(reinterpret_cast(&orig_dst_addr)); + } + } + + return EMPTY_STRING; +} + +uint32_t Network::ConnectionImpl::destinationPort() { + if (fd_ != -1) { + sockaddr_storage orig_dst_addr; + memset(&orig_dst_addr, 0, sizeof(orig_dst_addr)); + bool success = Utility::getOriginalDst(fd_, &orig_dst_addr); + if (success) { + return Utility::getAddressPort(reinterpret_cast(&orig_dst_addr)); + } + } + + return 0; +} + ClientConnectionImpl::ClientConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, - const std::string& url) - : ConnectionImpl(dispatcher, fd, url) {} + const std::string& url, uint32_t port) + : ConnectionImpl(dispatcher, fd, url, port) {} Network::ClientConnectionPtr ClientConnectionImpl::create(Event::DispatcherImpl& dispatcher, const std::string& url) { @@ -412,7 +439,8 @@ Network::ClientConnectionPtr ClientConnectionImpl::create(Event::DispatcherImpl& TcpClientConnectionImpl::TcpClientConnectionImpl(Event::DispatcherImpl& dispatcher, const std::string& url) - : ClientConnectionImpl(dispatcher, socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0), url) {} + : ClientConnectionImpl(dispatcher, socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0), url, + Network::Utility::portFromUrl(url)) {} void TcpClientConnectionImpl::connect() { AddrInfoPtr addr_info = Utility::resolveTCP(Utility::hostFromUrl(remote_address_), @@ -422,7 +450,7 @@ void TcpClientConnectionImpl::connect() { UdsClientConnectionImpl::UdsClientConnectionImpl(Event::DispatcherImpl& dispatcher, const std::string& url) - : ClientConnectionImpl(dispatcher, socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0), url) {} + : ClientConnectionImpl(dispatcher, socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0), url, 0) {} void UdsClientConnectionImpl::connect() { sockaddr_un addr = Utility::resolveUnixDomainSocket(Utility::pathFromUrl(remote_address_)); diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index f0e3151728af..53951c9661cb 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -37,7 +37,8 @@ class ConnectionImpl : public virtual Connection, public BufferSource, protected Logger::Loggable { public: - ConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, const std::string& remote_address); + ConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, const std::string& remote_address, + uint32_t remote_port); ~ConnectionImpl(); // Network::FilterManager @@ -56,6 +57,9 @@ class ConnectionImpl : public virtual Connection, void readDisable(bool disable) override; bool readEnabled() override; const std::string& remoteAddress() override { return remote_address_; } + uint32_t remotePort() override { return remote_port_; }; + const std::string destinationAddress() override; + uint32_t destinationPort() override; void setBufferStats(const BufferStats& stats) override; Ssl::Connection* ssl() override { return nullptr; } State state() override; @@ -79,6 +83,7 @@ class ConnectionImpl : public virtual Connection, FilterManagerImpl filter_manager_; const std::string remote_address_; + uint32_t remote_port_; Buffer::OwnedImpl read_buffer_; Buffer::OwnedImpl write_buffer_; @@ -122,7 +127,8 @@ class ConnectionImpl : public virtual Connection, */ class ClientConnectionImpl : public ConnectionImpl, virtual public ClientConnection { public: - ClientConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, const std::string& url); + ClientConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, const std::string& url, + uint32_t port); static Network::ClientConnectionPtr create(Event::DispatcherImpl& dispatcher, const std::string& url); diff --git a/source/common/network/listener_impl.cc b/source/common/network/listener_impl.cc index cc95ca1baabf..1535b97f7fd5 100644 --- a/source/common/network/listener_impl.cc +++ b/source/common/network/listener_impl.cc @@ -77,12 +77,12 @@ void ListenerImpl::newConnection(int fd, sockaddr* addr) { if (use_proxy_proto_) { proxy_protocol_.newConnection(dispatcher_, fd, *this); } else { - newConnection(fd, getAddressName(addr)); + newConnection(fd, getAddressName(addr), getAddressPort(addr)); } } -void ListenerImpl::newConnection(int fd, const std::string& remote_address) { - ConnectionPtr new_connection(new ConnectionImpl(dispatcher_, fd, remote_address)); +void ListenerImpl::newConnection(int fd, const std::string& remote_address, uint32_t remote_port) { + ConnectionPtr new_connection(new ConnectionImpl(dispatcher_, fd, remote_address, remote_port)); cb_.onNewConnection(std::move(new_connection)); } @@ -90,12 +90,14 @@ void SslListenerImpl::newConnection(int fd, sockaddr* addr) { if (use_proxy_proto_) { proxy_protocol_.newConnection(dispatcher_, fd, *this); } else { - newConnection(fd, getAddressName(addr)); + newConnection(fd, getAddressName(addr), getAddressPort(addr)); } } -void SslListenerImpl::newConnection(int fd, const std::string& remote_address) { - ConnectionPtr new_connection(new Ssl::ConnectionImpl(dispatcher_, fd, remote_address, ssl_ctx_, +void SslListenerImpl::newConnection(int fd, const std::string& remote_address, + uint32_t remote_port) { + ConnectionPtr new_connection(new Ssl::ConnectionImpl(dispatcher_, fd, remote_address, remote_port, + ssl_ctx_, Ssl::ConnectionImpl::InitialState::Server)); cb_.onNewConnection(std::move(new_connection)); } diff --git a/source/common/network/listener_impl.h b/source/common/network/listener_impl.h index d82103970b24..1092166a4792 100644 --- a/source/common/network/listener_impl.h +++ b/source/common/network/listener_impl.h @@ -33,8 +33,9 @@ class ListenerImpl : public Listener { * Accept/process a new connection with the given remote address. * @param fd supplies the new connection's fd. * @param remote_address supplies the remote address for the new connection. + * @param remote_address supplies the remote port for the new connection. */ - virtual void newConnection(int fd, const std::string& remote_address); + virtual void newConnection(int fd, const std::string& remote_address, uint32_t remote_port); /** * @return the socket supplied to the listener at construction time @@ -73,7 +74,7 @@ class SslListenerImpl : public ListenerImpl { // ListenerImpl void newConnection(int fd, sockaddr* addr) override; - void newConnection(int fd, const std::string& remote_address) override; + void newConnection(int fd, const std::string& remote_address, uint32_t remote_port) override; private: Ssl::Context& ssl_ctx_; diff --git a/source/common/network/proxy_protocol.cc b/source/common/network/proxy_protocol.cc index 094317146c85..1a0f856f82e5 100644 --- a/source/common/network/proxy_protocol.cc +++ b/source/common/network/proxy_protocol.cc @@ -68,7 +68,10 @@ void ProxyProtocol::ActiveConnection::onReadWorker() { removeFromList(parent_.connections_); - listener.newConnection(fd, remote_address); + // Technically we could extract the remote port from the PROXY protocol header + // and pass it to the listener. However, the listener does not care about + // client-side ports for downstream connections, so we can just pass a 0 + listener.newConnection(fd, remote_address, 0); } void ProxyProtocol::ActiveConnection::close() { diff --git a/source/common/network/utility.cc b/source/common/network/utility.cc index 377f4159c9df..797c62a7e914 100644 --- a/source/common/network/utility.cc +++ b/source/common/network/utility.cc @@ -11,12 +11,8 @@ namespace Network { -IpWhiteList::IpWhiteList(const Json::Object& config) { - if (!config.hasObject("ip_white_list")) { - return; - } - - for (const std::string& entry : config.getStringArray("ip_white_list")) { +IpList::IpList(const std::vector& subnets) { + for (const std::string& entry : subnets) { std::vector parts = StringUtil::split(entry, '/'); if (parts.size() != 2) { throw EnvoyException( @@ -37,33 +33,32 @@ IpWhiteList::IpWhiteList(const Json::Object& config) { fmt::format("invalid ipv4/mask combo '{}' (mask bits must be <= 32)", entry)); } - Ipv4Entry white_list_entry; - white_list_entry.ipv4_address_ = ntohl(addr.s_addr); + Ipv4Entry list_entry; + list_entry.ipv4_address_ = ntohl(addr.s_addr); // The 1ULL below makes sure that the RHS is computed as a 64-bit value, so that we do not // over-shift to the left when mask = 0. The assignment to ipv4_mask_ then truncates // the value back to 32 bits. - white_list_entry.ipv4_mask_ = ~((1ULL << (32 - mask)) - 1); + list_entry.ipv4_mask_ = ~((1ULL << (32 - mask)) - 1); // Check to make sure applying the mask to the address equals the address. This can prevent // user error. - if ((white_list_entry.ipv4_address_ & white_list_entry.ipv4_mask_) != - white_list_entry.ipv4_address_) { + if ((list_entry.ipv4_address_ & list_entry.ipv4_mask_) != list_entry.ipv4_address_) { throw EnvoyException( fmt::format("invalid ipv4/mask combo '{}' ((address & mask) != address)", entry)); } - ipv4_white_list_.push_back(white_list_entry); + ipv4_list_.push_back(list_entry); } } -bool IpWhiteList::contains(const std::string& remote_address) const { +bool IpList::contains(const std::string& remote_address) const { in_addr addr; int rc = inet_pton(AF_INET, remote_address.c_str(), &addr); if (1 != rc) { return false; } - for (const Ipv4Entry& entry : ipv4_white_list_) { + for (const Ipv4Entry& entry : ipv4_list_) { if ((ntohl(addr.s_addr) & entry.ipv4_mask_) == entry.ipv4_address_) { return true; } @@ -72,6 +67,10 @@ bool IpWhiteList::contains(const std::string& remote_address) const { return false; } +IpWhiteList::IpWhiteList(const Json::Object& config) + : ip_list_(config.hasObject("ip_white_list") ? config.getStringArray("ip_white_list") + : std::vector()) {} + const std::string Utility::TCP_SCHEME = "tcp://"; const std::string Utility::UNIX_SCHEME = "unix://"; @@ -235,4 +234,35 @@ bool Utility::getOriginalDst(int fd, sockaddr_storage* orig_addr) { return (status == 0); } +void Utility::parsePortRangeList(const std::string& string, std::list& list) { + std::vector ranges = StringUtil::split(string.c_str(), ','); + for (const std::string& s : ranges) { + uint32_t min = 0; + uint32_t max = 0; + char dash = 0; + + std::stringstream ss(s); + + if (s.find('-') != std::string::npos) { + ss >> min; + ss >> dash; + ss >> max; + } else { + ss >> min; + max = min; + } + + list.emplace_back(PortRange(min, max)); + } +} + +bool Utility::portInRangeList(uint32_t port, const std::list& list) { + for (const Network::PortRange& p : list) { + if (p.contains(port)) { + return true; + } + } + return false; +} + } // Network diff --git a/source/common/network/utility.h b/source/common/network/utility.h index 85f5044c729a..2c1b3620fd4a 100644 --- a/source/common/network/utility.h +++ b/source/common/network/utility.h @@ -16,11 +16,13 @@ namespace Network { * Utility class for keeping a list of IPV4 addresses and masks, and then determining whether an * IP address is in the address/mask list. */ -class IpWhiteList { +class IpList { public: - IpWhiteList(const Json::Object& config); + IpList(const std::vector& subnets); + IpList(){}; - bool contains(const std::string& remote_address) const; + bool contains(const std::string& address) const; + bool empty() const { return ipv4_list_.empty(); } private: struct Ipv4Entry { @@ -28,9 +30,34 @@ class IpWhiteList { uint32_t ipv4_mask_; }; - std::vector ipv4_white_list_; + std::vector ipv4_list_; +}; + +class IpWhiteList { +public: + IpWhiteList(const Json::Object& config); + bool contains(const std::string& address) const { return ip_list_.contains(address); } + +private: + IpList ip_list_; +}; + +/** + * Utility class to represent TCP/UDP port range + */ +class PortRange { +public: + PortRange(uint32_t min, uint32_t max) : min_(min), max_(max) {} + + bool contains(uint32_t port) const { return (port >= min_ && port <= max_); } + +private: + uint32_t min_; + uint32_t max_; }; +typedef std::list PortRangeList; + /** * Common network utility routines. */ @@ -130,6 +157,22 @@ class Utility { * @return true if the operation succeeded, false otherwise */ static bool getOriginalDst(int fd, sockaddr_storage* orig_addr); + + /** + * Parses a string containing a comma-separated list of port numbers and/or + * port ranges and appends the values to a caller-provided list of PortRange structures. + * @param str is the string containing the port numbers and ranges + * @param list is the list to append the new data structures to + */ + static void parsePortRangeList(const std::string& string, std::list& list); + + /** + * Checks whether a given port number appears in at least one of the port ranges in a list + * @param port is the port number to search + * @param list the list of port ranges in which the port may appear + * @return whether the port appears in at least one of the ranges in the list + */ + static bool portInRangeList(uint32_t port, const std::list& list); }; } // Network diff --git a/source/common/ssl/connection_impl.cc b/source/common/ssl/connection_impl.cc index 901e49e4268a..77f402a3068d 100644 --- a/source/common/ssl/connection_impl.cc +++ b/source/common/ssl/connection_impl.cc @@ -10,8 +10,9 @@ namespace Ssl { ConnectionImpl::ConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, - const std::string& remote_address, Context& ctx, InitialState state) - : Network::ConnectionImpl(dispatcher, fd, remote_address), + const std::string& remote_address, uint32_t remote_port, + Context& ctx, InitialState state) + : Network::ConnectionImpl(dispatcher, fd, remote_address, remote_port), ctx_(dynamic_cast(ctx)), ssl_(ctx_.newSsl()) { BIO* bio = BIO_new_socket(fd, 0); SSL_set_bio(ssl_.get(), bio, bio); @@ -187,8 +188,8 @@ std::string ConnectionImpl::sha256PeerCertificateDigest() { ClientConnectionImpl::ClientConnectionImpl(Event::DispatcherImpl& dispatcher, Context& ctx, const std::string& url) - : ConnectionImpl(dispatcher, socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0), url, ctx, - InitialState::Client) {} + : ConnectionImpl(dispatcher, socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0), url, + Network::Utility::portFromUrl(url), ctx, InitialState::Client) {} void ClientConnectionImpl::connect() { Network::AddrInfoPtr addr_info = diff --git a/source/common/ssl/connection_impl.h b/source/common/ssl/connection_impl.h index 0f38d29275b9..ea0118818786 100644 --- a/source/common/ssl/connection_impl.h +++ b/source/common/ssl/connection_impl.h @@ -11,7 +11,7 @@ class ConnectionImpl : public Network::ConnectionImpl, public Connection { enum class InitialState { Client, Server }; ConnectionImpl(Event::DispatcherImpl& dispatcher, int fd, const std::string& remote_address, - Context& ctx, InitialState state); + uint32_t remote_port, Context& ctx, InitialState state); ~ConnectionImpl(); // Network::Connection diff --git a/test/common/filter/tcp_proxy_test.cc b/test/common/filter/tcp_proxy_test.cc index 68e322081f39..53b4b4b23244 100644 --- a/test/common/filter/tcp_proxy_test.cc +++ b/test/common/filter/tcp_proxy_test.cc @@ -12,18 +12,39 @@ using testing::_; using testing::NiceMock; using testing::Return; +using testing::ReturnRefOfCopy; using testing::SaveArg; namespace Filter { -TEST(TcpProxyConfigTest, NoCluster) { +TEST(TcpProxyConfigTest, NoRouteConfig) { std::string json = R"EOF( { - "cluster": "fake_cluster", "stat_prefix": "name" } )EOF"; + Json::ObjectPtr config = Json::Factory::LoadFromString(json); + NiceMock cluster_manager; + EXPECT_THROW( + TcpProxyConfig(*config, cluster_manager, cluster_manager.cluster_.info_->stats_store_), + EnvoyException); +} + +TEST(TcpProxyConfigTest, NoCluster) { + std::string json = R"EOF( + { + "stat_prefix": "name", + "route_config": { + "routes": [ + { + "cluster": "fake_cluster" + } + ] + } + } + )EOF"; + Json::ObjectPtr config = Json::Factory::LoadFromString(json); NiceMock cluster_manager; EXPECT_CALL(cluster_manager, get("fake_cluster")).WillOnce(Return(nullptr)); @@ -32,13 +53,217 @@ TEST(TcpProxyConfigTest, NoCluster) { EnvoyException); } +TEST(TcpProxyConfigTest, Routes) { + std::string json = R"EOF( + { + "stat_prefix": "name", + "route_config": { + "routes": [ + { + "destination_ip_list": [ + "10.10.10.10/32", + "10.10.11.0/24", + "10.11.0.0/16", + "11.0.0.0/8", + "128.0.0.0/1" + ], + "cluster": "with_destination_ip_list" + }, + { + "destination_ports": "1-1024,2048-4096,12345", + "cluster": "with_destination_ports" + }, + { + "source_ports": "23457,23459", + "cluster": "with_source_ports" + }, + { + "destination_ip_list": [ + "10.0.0.0/24" + ], + "source_ip_list": [ + "20.0.0.0/24" + ], + "destination_ports" : "10000", + "source_ports": "20000", + "cluster": "with_everything" + }, + { + "cluster": "catch_all" + } + ] + } + } + )EOF"; + + Json::ObjectPtr json_config = Json::Factory::LoadFromString(json); + NiceMock cm_; + + // The TcpProxyConfig constructor checks if the clusters mentioned in the route_config are valid. + // We need to make sure to return a non-null pointer for each, otherwise the constructor will + // throw an exception and fail. + EXPECT_CALL(cm_, get("with_destination_ip_list")).WillRepeatedly(Return(cm_.cluster_.info_)); + EXPECT_CALL(cm_, get("with_destination_ports")).WillRepeatedly(Return(cm_.cluster_.info_)); + EXPECT_CALL(cm_, get("with_source_ports")).WillRepeatedly(Return(cm_.cluster_.info_)); + EXPECT_CALL(cm_, get("with_everything")).WillRepeatedly(Return(cm_.cluster_.info_)); + EXPECT_CALL(cm_, get("catch_all")).WillRepeatedly(Return(cm_.cluster_.info_)); + + TcpProxyConfig config_obj(*json_config, cm_, cm_.cluster_.info_->stats_store_); + + { + // hit route with destination_ip (10.10.10.10/32) + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.10.10.10")); + EXPECT_EQ(std::string("with_destination_ip_list"), config_obj.getClusterForConnection(connection)); + } + + { + // fall-through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.10.10.11")); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with destination_ip (10.10.11.0/24) + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.10.11.11")); + EXPECT_EQ(std::string("with_destination_ip_list"), config_obj.getClusterForConnection(connection)); + } + + { + // fall-through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.10.12.12")); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with destination_ip (10.11.0.0/16) + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.11.11.11")); + EXPECT_EQ(std::string("with_destination_ip_list"), config_obj.getClusterForConnection(connection)); + } + + { + // fall-through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.12.12.12")); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with destination_ip (11.0.0.0/8) + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("11.11.11.11")); + EXPECT_EQ(std::string("with_destination_ip_list"), config_obj.getClusterForConnection(connection)); + } + + { + // fall-through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("12.12.12.12")); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with destination_ip (128.0.0.0/8) + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("128.255.255.255")); + EXPECT_EQ(std::string("with_destination_ip_list"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with destination port range + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("1.2.3.4")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(12345)); + EXPECT_EQ(std::string("with_destination_ports"), + config_obj.getClusterForConnection(connection)); + } + + { + // fall through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("1.2.3.4")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(23456)); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit route with source port range + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("1.2.3.4")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(23456)); + EXPECT_CALL(connection, remotePort()).WillRepeatedly(Return(23459)); + EXPECT_EQ(std::string("with_source_ports"), config_obj.getClusterForConnection(connection)); + } + + { + // fall through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("1.2.3.4")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(23456)); + EXPECT_CALL(connection, remotePort()).WillRepeatedly(Return(23458)); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } + + { + // hit the route with all criterias present + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.0.0.0")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(10000)); + EXPECT_CALL(connection, remoteAddress()) + .WillRepeatedly(ReturnRefOfCopy(std::string("20.0.0.0"))); + EXPECT_CALL(connection, remotePort()).WillRepeatedly(Return(20000)); + EXPECT_EQ(std::string("with_everything"), config_obj.getClusterForConnection(connection)); + } + + { + // fall through + NiceMock connection; + EXPECT_CALL(connection, destinationAddress()).WillRepeatedly(Return("10.0.0.0")); + EXPECT_CALL(connection, destinationPort()).WillRepeatedly(Return(10000)); + EXPECT_CALL(connection, remoteAddress()) + .WillRepeatedly(ReturnRefOfCopy(std::string("30.0.0.0"))); + EXPECT_CALL(connection, remotePort()).WillRepeatedly(Return(20000)); + EXPECT_EQ(std::string("catch_all"), config_obj.getClusterForConnection(connection)); + } +} + +TEST(TcpProxyConfigTest, EmptyRouteConfig) { + std::string json = R"EOF( + { + "stat_prefix": "name", + "route_config": { + "routes": [ + ] + } + } + )EOF"; + + Json::ObjectPtr json_config = Json::Factory::LoadFromString(json); + NiceMock cm_; + + TcpProxyConfig config_obj(*json_config, cm_, cm_.cluster_.info_->stats_store_); + + NiceMock connection; + EXPECT_EQ(std::string(""), config_obj.getClusterForConnection(connection)); +} + class TcpProxyTest : public testing::Test { public: TcpProxyTest() { std::string json = R"EOF( { - "cluster": "fake_cluster", - "stat_prefix": "name" + "stat_prefix": "name", + "route_config": { + "routes": [ + { + "cluster": "fake_cluster" + } + ] + } } )EOF"; diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index 950214050751..0fb75d9fc659 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -43,7 +43,7 @@ TEST(ConnectionImplUtility, updateBufferStats) { TEST(ConnectionImplDeathTest, BadFd) { Event::DispatcherImpl dispatcher; - EXPECT_DEATH(ConnectionImpl(dispatcher, -1, "127.0.0.1"), ".*assert failure: fd_ != -1.*"); + EXPECT_DEATH(ConnectionImpl(dispatcher, -1, "127.0.0.1", 0), ".*assert failure: fd_ != -1.*"); } struct MockBufferStats { diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index cbd2d6460e24..4bb767e687f4 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -122,8 +122,14 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { std::string tcp_proxy_json = R"EOF( { - "cluster": "fake_cluster", - "stat_prefix": "name" + "stat_prefix": "name", + "route_config": { + "routes": [ + { + "cluster": "fake_cluster" + } + ] + } } )EOF"; diff --git a/test/common/network/listener_impl_test.cc b/test/common/network/listener_impl_test.cc index 72bd5020424c..d33770c781d5 100644 --- a/test/common/network/listener_impl_test.cc +++ b/test/common/network/listener_impl_test.cc @@ -52,8 +52,8 @@ class TestListenerImpl : public ListenerImpl { MOCK_METHOD2(newConnection_, void(int, sockaddr*)); void newConnection(int fd, sockaddr* addr) override { newConnection_(fd, addr); } - void newConnection(int fd, const std::string& addr) override { - ListenerImpl::newConnection(fd, addr); + void newConnection(int fd, const std::string& addr, uint32_t port) override { + ListenerImpl::newConnection(fd, addr, port); } protected: diff --git a/test/common/network/utility_test.cc b/test/common/network/utility_test.cc index 9b4226a7d3c1..bfcda90253d2 100644 --- a/test/common/network/utility_test.cc +++ b/test/common/network/utility_test.cc @@ -138,4 +138,42 @@ TEST(NetworkUtility, loopbackAddress) { } } +TEST(NetworkUtility, PortRangeList) { + { + std::string port_range_str = "1"; + std::list port_range_list; + + Utility::parsePortRangeList(port_range_str, port_range_list); + EXPECT_TRUE(Utility::portInRangeList(1, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(2, port_range_list)); + } + + { + std::string port_range_str = "1024-2048"; + std::list port_range_list; + + Utility::parsePortRangeList(port_range_str, port_range_list); + EXPECT_TRUE(Utility::portInRangeList(1024, port_range_list)); + EXPECT_TRUE(Utility::portInRangeList(2048, port_range_list)); + EXPECT_TRUE(Utility::portInRangeList(1536, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(1023, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(2049, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(0, port_range_list)); + } + + { + std::string port_range_str = "1,10-100,1000-10000,65535"; + std::list port_range_list; + + Utility::parsePortRangeList(port_range_str, port_range_list); + EXPECT_TRUE(Utility::portInRangeList(1, port_range_list)); + EXPECT_TRUE(Utility::portInRangeList(50, port_range_list)); + EXPECT_TRUE(Utility::portInRangeList(5000, port_range_list)); + EXPECT_TRUE(Utility::portInRangeList(65535, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(2, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(200, port_range_list)); + EXPECT_FALSE(Utility::portInRangeList(20000, port_range_list)); + } +} + } // Network diff --git a/test/config/integration/server.json b/test/config/integration/server.json index 531684730de1..1e639211ec8a 100644 --- a/test/config/integration/server.json +++ b/test/config/integration/server.json @@ -188,7 +188,16 @@ "filters": [ { "type": "read", "name": "tcp_proxy", - "config": { "cluster": "cluster_1", "stat_prefix": "test_tcp" } + "config": { + "stat_prefix": "test_tcp", + "route_config": { + "routes": [ + { + "cluster": "cluster_1" + } + ] + } + } } ] }], diff --git a/test/config/integration/server_http2.json b/test/config/integration/server_http2.json index 992fd6e456ab..f7e12a9ecc8c 100644 --- a/test/config/integration/server_http2.json +++ b/test/config/integration/server_http2.json @@ -157,7 +157,16 @@ "filters": [ { "type": "read", "name": "tcp_proxy", - "config": { "cluster": "cluster_1", "stat_prefix": "test_tcp" } + "config": { + "stat_prefix": "test_tcp", + "route_config": { + "routes": [ + { + "cluster": "cluster_1" + } + ] + } + } } ] }], diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 60d05ac73ba9..1877c6fffc3d 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -52,6 +52,9 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_METHOD1(readDisable, void(bool disable)); MOCK_METHOD0(readEnabled, bool()); MOCK_METHOD0(remoteAddress, const std::string&()); + MOCK_METHOD0(remotePort, uint32_t()); + MOCK_METHOD0(destinationAddress, const std::string()); + MOCK_METHOD0(destinationPort, uint32_t()); MOCK_METHOD1(setBufferStats, void(const BufferStats& stats)); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_METHOD0(state, State()); @@ -81,6 +84,9 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_METHOD1(readDisable, void(bool disable)); MOCK_METHOD0(readEnabled, bool()); MOCK_METHOD0(remoteAddress, const std::string&()); + MOCK_METHOD0(remotePort, uint32_t()); + MOCK_METHOD0(destinationAddress, const std::string()); + MOCK_METHOD0(destinationPort, uint32_t()); MOCK_METHOD1(setBufferStats, void(const BufferStats& stats)); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_METHOD0(state, State());