diff --git a/include/envoy/redis/codec.h b/include/envoy/redis/codec.h index edfc218f164b..c4ca0cd3560b 100644 --- a/include/envoy/redis/codec.h +++ b/include/envoy/redis/codec.h @@ -19,6 +19,11 @@ class RespValue { RespValue() : type_(RespType::Null) {} ~RespValue() { cleanup(); } + /** + * Convert a RESP value to a string for debugging purposes. + */ + std::string toString() const; + /** * The following are getters and setters for the internal value. A RespValue start as null, * and much change type via type() before the following methods can be used. diff --git a/include/envoy/redis/command_splitter.h b/include/envoy/redis/command_splitter.h new file mode 100644 index 000000000000..8fae618cb3df --- /dev/null +++ b/include/envoy/redis/command_splitter.h @@ -0,0 +1,58 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/redis/codec.h" + +namespace Redis { +namespace CommandSplitter { + +/** + * A handle to a split request. + */ +class SplitRequest { +public: + virtual ~SplitRequest() {} + + /** + * Cancel the request. No further request callbacks will be called. + */ + virtual void cancel() PURE; +}; + +typedef std::unique_ptr SplitRequestPtr; + +/** + * Split request callbacks. + */ +class SplitCallbacks { +public: + virtual ~SplitCallbacks() {} + + /** + * Called when the response is ready. + * @param value supplies the response which is now owned by the callee. + */ + virtual void onResponse(RespValuePtr&& value) PURE; +}; + +/** + * A command splitter that takes incoming redis commands and splits them as appropriate to a + * backend connection pool. + */ +class Instance { +public: + virtual ~Instance() {} + + /** + * Make a split redis request. + * @param request supplies the split request to make. + * @param callbacks supplies the split request completion callbacks. + * @return SplitRequestPtr a handle to the active request or nullptr if the request has already + * been satisfied (via onResponse() being called). The splitter ALWAYS calls + * onResponse() for a given request. + */ + virtual SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) PURE; +}; + +} // CommandSplitter +} // Redis diff --git a/include/envoy/redis/conn_pool.h b/include/envoy/redis/conn_pool.h index a783a0652b75..29d2cc515c54 100644 --- a/include/envoy/redis/conn_pool.h +++ b/include/envoy/redis/conn_pool.h @@ -9,9 +9,9 @@ namespace ConnPool { /** * A handle to an outbound request. */ -class ActiveRequest { +class PoolRequest { public: - virtual ~ActiveRequest() {} + virtual ~PoolRequest() {} /** * Cancel the request. No further request callbacks will be called. @@ -22,9 +22,9 @@ class ActiveRequest { /** * Outbound request callbacks. */ -class ActiveRequestCallbacks { +class PoolCallbacks { public: - virtual ~ActiveRequestCallbacks() {} + virtual ~PoolCallbacks() {} /** * Called when a pipelined response is received. @@ -59,10 +59,10 @@ class Client { * Make a pipelined request to the remote redis server. * @param request supplies the RESP request to make. * @param callbacks supplies the request callbacks. - * @return ActiveRequest* a handle to the active request. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. */ - virtual ActiveRequest* makeRequest(const RespValue& request, - ActiveRequestCallbacks& callbacks) PURE; + virtual PoolRequest* makeRequest(const RespValue& request, PoolCallbacks& callbacks) PURE; }; typedef std::unique_ptr ClientPtr; @@ -93,12 +93,14 @@ class Instance { * @param hash_key supplies the key to use for consistent hashing. * @param request supplies the request to make. * @param callbacks supplies the request completion callbacks. - * @return ActiveRequest* a handle to the active request or nullptr if the request could not - * be made for some reason. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. */ - virtual ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks) PURE; + virtual PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks) PURE; }; +typedef std::unique_ptr InstancePtr; + } // ConnPool } // Redis diff --git a/source/common/CMakeLists.txt b/source/common/CMakeLists.txt index efa3d490eae7..ce9ea91c2e29 100644 --- a/source/common/CMakeLists.txt +++ b/source/common/CMakeLists.txt @@ -88,6 +88,7 @@ add_library( profiler/profiler.cc ratelimit/ratelimit_impl.cc redis/codec_impl.cc + redis/command_splitter_impl.cc redis/conn_pool_impl.cc redis/proxy_filter.cc router/config_impl.cc diff --git a/source/common/http/filter/ratelimit.h b/source/common/http/filter/ratelimit.h index 64994fc58da3..437086e7de07 100644 --- a/source/common/http/filter/ratelimit.h +++ b/source/common/http/filter/ratelimit.h @@ -23,11 +23,11 @@ enum class FilterRequestType { Internal, External, Both }; /** * Global configuration for the HTTP rate limit filter. */ -class FilterConfig : Json::JsonValidator { +class FilterConfig : Json::Validator { public: FilterConfig(const Json::Object& config, const LocalInfo::LocalInfo& local_info, Stats::Store& global_store, Runtime::Loader& runtime, Upstream::ClusterManager& cm) - : Json::JsonValidator(config, Json::Schema::RATE_LIMIT_HTTP_FILTER_SCHEMA), + : Json::Validator(config, Json::Schema::RATE_LIMIT_HTTP_FILTER_SCHEMA), domain_(config.getString("domain")), stage_(static_cast(config.getInteger("stage", 0))), request_type_(stringToType(config.getString("request_type", "both"))), diff --git a/source/common/json/json_validator.h b/source/common/json/json_validator.h index c01c972522fd..0e957746ec93 100644 --- a/source/common/json/json_validator.h +++ b/source/common/json/json_validator.h @@ -7,9 +7,9 @@ namespace Json { /** * Base class to inherit from to validate config schema before initializing member variables. */ -class JsonValidator { +class Validator { public: - JsonValidator(const Json::Object& config, const std::string& schema) { + Validator(const Json::Object& config, const std::string& schema) { config.validateSchema(schema); } }; diff --git a/source/common/redis/codec_impl.cc b/source/common/redis/codec_impl.cc index b350b5356b00..802bc8ad4ffc 100644 --- a/source/common/redis/codec_impl.cc +++ b/source/common/redis/codec_impl.cc @@ -5,6 +5,31 @@ namespace Redis { +std::string RespValue::toString() const { + switch (type_) { + case RespType::Array: { + std::string ret = "["; + for (uint64_t i = 0; i < asArray().size(); i++) { + ret += asArray()[i].toString(); + if (i != asArray().size() - 1) { + ret += ", "; + } + } + return ret + "]"; + } + case RespType::SimpleString: + case RespType::BulkString: + case RespType::Error: + return fmt::format("\"{}\"", asString()); + case RespType::Null: + return "null"; + case RespType::Integer: + return std::to_string(asInteger()); + } + + NOT_REACHED; +} + std::vector& RespValue::asArray() { ASSERT(type_ == RespType::Array); return array_; diff --git a/source/common/redis/command_splitter_impl.cc b/source/common/redis/command_splitter_impl.cc new file mode 100644 index 000000000000..72902c2fed42 --- /dev/null +++ b/source/common/redis/command_splitter_impl.cc @@ -0,0 +1,170 @@ +#include "command_splitter_impl.h" + +#include "common/common/assert.h" + +namespace Redis { +namespace CommandSplitter { + +RespValuePtr Utility::makeError(const std::string& error) { + RespValuePtr response(new RespValue()); + response->type(RespType::Error); + response->asString() = error; + return response; +} + +SplitRequestPtr AllParamsToOneServerCommandHandler::startRequest(const RespValue& request, + SplitCallbacks& callbacks) { + std::unique_ptr request_handle(new SplitRequestImpl(callbacks)); + request_handle->handle_ = + conn_pool_.makeRequest(request.asArray()[1].asString(), request, *request_handle); + if (!request_handle->handle_) { + callbacks.onResponse(Utility::makeError("no upstream host")); + return nullptr; + } + + return std::move(request_handle); +} + +AllParamsToOneServerCommandHandler::SplitRequestImpl::~SplitRequestImpl() { ASSERT(!handle_); } + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::cancel() { + handle_->cancel(); + handle_ = nullptr; +} + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::onResponse(RespValuePtr&& response) { + handle_ = nullptr; + log_debug("redis: response: '{}'", response->toString()); + callbacks_.onResponse(std::move(response)); +} + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::onFailure() { + handle_ = nullptr; + callbacks_.onResponse(Utility::makeError("upstream failure")); +} + +SplitRequestPtr MGETCommandHandler::startRequest(const RespValue& request, + SplitCallbacks& callbacks) { + std::unique_ptr request_handle( + new SplitRequestImpl(callbacks, request.asArray().size() - 1)); + + // Create the get request that we will use for each split get below. + std::vector values(2); + values[0].type(RespType::BulkString); + values[0].asString() = "get"; + values[1].type(RespType::BulkString); + RespValue single_mget; + single_mget.type(RespType::Array); + single_mget.asArray().swap(values); + + for (uint64_t i = 1; i < request.asArray().size(); i++) { + request_handle->pending_requests_.emplace_back(*request_handle, i - 1); + SplitRequestImpl::PendingRequest& pending_request = request_handle->pending_requests_.back(); + + single_mget.asArray()[1].asString() = request.asArray()[i].asString(); + log_debug("redis: parallel get: '{}'", single_mget.toString()); + pending_request.handle_ = + conn_pool_.makeRequest(request.asArray()[i].asString(), single_mget, pending_request); + if (!pending_request.handle_) { + pending_request.onResponse(Utility::makeError("no upstream host")); + } + } + + return request_handle->pending_responses_ > 0 ? std::move(request_handle) : nullptr; +} + +MGETCommandHandler::SplitRequestImpl::SplitRequestImpl(SplitCallbacks& callbacks, + uint32_t num_responses) + : callbacks_(callbacks), pending_responses_(num_responses) { + pending_response_.reset(new RespValue()); + pending_response_->type(RespType::Array); + std::vector responses(num_responses); + pending_response_->asArray().swap(responses); + pending_requests_.reserve(num_responses); +} + +MGETCommandHandler::SplitRequestImpl::~SplitRequestImpl() { +#ifndef NDEBUG + for (const PendingRequest& request : pending_requests_) { + ASSERT(!request.handle_); + } +#endif +} + +void MGETCommandHandler::SplitRequestImpl::cancel() { + for (PendingRequest& request : pending_requests_) { + if (request.handle_) { + request.handle_->cancel(); + request.handle_ = nullptr; + } + } +} + +void MGETCommandHandler::SplitRequestImpl::onResponse(RespValuePtr&& value, uint32_t index) { + pending_requests_[index].handle_ = nullptr; + + pending_response_->asArray()[index].type(value->type()); + switch (value->type()) { + case RespType::Array: + case RespType::Integer: { + pending_response_->asArray()[index].type(RespType::Error); + pending_response_->asArray()[index].asString() = "upstream protocol error"; + break; + } + case RespType::SimpleString: + case RespType::BulkString: + case RespType::Error: { + pending_response_->asArray()[index].asString().swap(value->asString()); + break; + } + case RespType::Null: + break; + } + + ASSERT(pending_responses_ > 0); + if (--pending_responses_ == 0) { + log_debug("redis: response: '{}'", pending_response_->toString()); + callbacks_.onResponse(std::move(pending_response_)); + } +} + +void MGETCommandHandler::SplitRequestImpl::onFailure(uint32_t index) { + onResponse(Utility::makeError("upstream failure"), index); +} + +InstanceImpl::InstanceImpl(ConnPool::InstancePtr&& conn_pool) + : conn_pool_(std::move(conn_pool)), all_to_one_handler_(*conn_pool_), + mget_handler_(*conn_pool_) { + // TODO(mattklein123) PERF: Make this a trie (like in header_map_impl). + // TODO(mattklein123): Make not case sensitive (like in header_map_impl). + command_map_.emplace("incr", all_to_one_handler_); + command_map_.emplace("incrby", all_to_one_handler_); + command_map_.emplace("mget", mget_handler_); +} + +SplitRequestPtr InstanceImpl::makeRequest(const RespValue& request, SplitCallbacks& callbacks) { + if (request.type() != RespType::Array || request.asArray().size() < 2) { + callbacks.onResponse(Utility::makeError("invalid request")); + return nullptr; + } + + for (const RespValue& value : request.asArray()) { + if (value.type() != RespType::BulkString) { + callbacks.onResponse(Utility::makeError("invalid request")); + return nullptr; + } + } + + auto handler = command_map_.find(request.asArray()[0].asString()); + if (handler == command_map_.end()) { + callbacks.onResponse(Utility::makeError( + fmt::format("unsupported command '{}'", request.asArray()[0].asString()))); + return nullptr; + } + + log_debug("redis: splitting '{}'", request.toString()); + return handler->second.get().startRequest(request, callbacks); +} + +} // CommandSplitter +} // Redis diff --git a/source/common/redis/command_splitter_impl.h b/source/common/redis/command_splitter_impl.h new file mode 100644 index 000000000000..f787777790b5 --- /dev/null +++ b/source/common/redis/command_splitter_impl.h @@ -0,0 +1,113 @@ +#pragma once + +#include "envoy/redis/command_splitter.h" +#include "envoy/redis/conn_pool.h" + +#include "common/common/logger.h" + +namespace Redis { +namespace CommandSplitter { + +class Utility { +public: + static RespValuePtr makeError(const std::string& error); +}; + +class CommandHandler { +public: + virtual ~CommandHandler() {} + + virtual SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) PURE; +}; + +class CommandHandlerBase { +protected: + CommandHandlerBase(ConnPool::Instance& conn_pool) : conn_pool_(conn_pool) {} + + ConnPool::Instance& conn_pool_; +}; + +class AllParamsToOneServerCommandHandler : public CommandHandler, + CommandHandlerBase, + Logger::Loggable { +public: + AllParamsToOneServerCommandHandler(ConnPool::Instance& conn_pool) + : CommandHandlerBase(conn_pool) {} + + // Redis::CommandSplitter::CommandHandler + SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + struct SplitRequestImpl : public SplitRequest, public ConnPool::PoolCallbacks { + SplitRequestImpl(SplitCallbacks& callbacks) : callbacks_(callbacks) {} + ~SplitRequestImpl(); + + // Redis::CommandSplitter::SplitRequest + void cancel() override; + + // Redis::ConnPool::PoolCallbacks + void onResponse(RespValuePtr&& value) override; + void onFailure() override; + + SplitCallbacks& callbacks_; + ConnPool::PoolRequest* handle_{}; + }; +}; + +class MGETCommandHandler : public CommandHandler, + CommandHandlerBase, + Logger::Loggable { +public: + MGETCommandHandler(ConnPool::Instance& conn_pool) : CommandHandlerBase(conn_pool) {} + + // Redis::CommandSplitter::CommandHandler + SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + struct SplitRequestImpl : public SplitRequest { + struct PendingRequest : public ConnPool::PoolCallbacks { + PendingRequest(SplitRequestImpl& parent, uint32_t index) : parent_(parent), index_(index) {} + + // Redis::ConnPool::PoolCallbacks + void onResponse(RespValuePtr&& value) override { + parent_.onResponse(std::move(value), index_); + } + void onFailure() override { parent_.onFailure(index_); } + + SplitRequestImpl& parent_; + const uint32_t index_; + ConnPool::PoolRequest* handle_{}; + }; + + SplitRequestImpl(SplitCallbacks& callbacks, uint32_t num_responses); + ~SplitRequestImpl(); + + void onResponse(RespValuePtr&& value, uint32_t index); + void onFailure(uint32_t index); + + // Redis::CommandSplitter::SplitRequest + void cancel() override; + + SplitCallbacks& callbacks_; + RespValuePtr pending_response_; + std::vector pending_requests_; + uint32_t pending_responses_; + }; +}; + +class InstanceImpl : public Instance, Logger::Loggable { +public: + InstanceImpl(ConnPool::InstancePtr&& conn_pool); + + // Redis::CommandSplitter::Instance + SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + ConnPool::InstancePtr conn_pool_; + AllParamsToOneServerCommandHandler all_to_one_handler_; + MGETCommandHandler mget_handler_; + std::unordered_map> command_map_; +}; + +} // CommandSplitter +} // Redis diff --git a/source/common/redis/conn_pool_impl.cc b/source/common/redis/conn_pool_impl.cc index 35c55cedbcdc..f971ce4c41a5 100644 --- a/source/common/redis/conn_pool_impl.cc +++ b/source/common/redis/conn_pool_impl.cc @@ -24,8 +24,7 @@ ClientImpl::~ClientImpl() { void ClientImpl::close() { connection_->close(Network::ConnectionCloseType::NoFlush); } -ActiveRequest* ClientImpl::makeRequest(const RespValue& request, - ActiveRequestCallbacks& callbacks) { +PoolRequest* ClientImpl::makeRequest(const RespValue& request, PoolCallbacks& callbacks) { ASSERT(connection_->state() == Network::Connection::State::Open); pending_requests_.emplace_back(callbacks); encoder_->encode(request, encoder_buffer_); @@ -86,8 +85,8 @@ InstanceImpl::InstanceImpl(const std::string& cluster_name, Upstream::ClusterMan }); } -ActiveRequest* InstanceImpl::makeRequest(const std::string& hash_key, const RespValue& value, - ActiveRequestCallbacks& callbacks) { +PoolRequest* InstanceImpl::makeRequest(const std::string& hash_key, const RespValue& value, + PoolCallbacks& callbacks) { return tls_.getTyped(tls_slot_).makeRequest(hash_key, value, callbacks); } @@ -113,9 +112,9 @@ void InstanceImpl::ThreadLocalPool::onHostsRemoved( } } -ActiveRequest* InstanceImpl::ThreadLocalPool::makeRequest(const std::string& hash_key, - const RespValue& request, - ActiveRequestCallbacks& callbacks) { +PoolRequest* InstanceImpl::ThreadLocalPool::makeRequest(const std::string& hash_key, + const RespValue& request, + PoolCallbacks& callbacks) { LbContextImpl lb_context(hash_key); Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context); if (!host) { diff --git a/source/common/redis/conn_pool_impl.h b/source/common/redis/conn_pool_impl.h index e22b92b7c88f..34eb896e99cd 100644 --- a/source/common/redis/conn_pool_impl.h +++ b/source/common/redis/conn_pool_impl.h @@ -28,7 +28,7 @@ class ClientImpl : public Client, public DecoderCallbacks, public Network::Conne connection_->addConnectionCallbacks(callbacks); } void close() override; - ActiveRequest* makeRequest(const RespValue& request, ActiveRequestCallbacks& callbacks) override; + PoolRequest* makeRequest(const RespValue& request, PoolCallbacks& callbacks) override; private: struct UpstreamReadFilter : public Network::ReadFilterBaseImpl { @@ -43,13 +43,13 @@ class ClientImpl : public Client, public DecoderCallbacks, public Network::Conne ClientImpl& parent_; }; - struct PendingRequest : public ActiveRequest { - PendingRequest(ActiveRequestCallbacks& callbacks) : callbacks_(callbacks) {} + struct PendingRequest : public PoolRequest { + PendingRequest(PoolCallbacks& callbacks) : callbacks_(callbacks) {} - // Redis::ConnPool::ActiveRequest + // Redis::ConnPool::PoolRequest void cancel() override; - ActiveRequestCallbacks& callbacks_; + PoolCallbacks& callbacks_; bool canceled_{}; }; @@ -88,8 +88,8 @@ class InstanceImpl : public Instance { ClientFactory& client_factory, ThreadLocal::Instance& tls); // Redis::ConnPool::Instance - ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks) override; + PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks) override; private: struct ThreadLocalPool; @@ -112,8 +112,8 @@ class InstanceImpl : public Instance { ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher, const std::string& cluster_name); - ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks); + PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks); void onHostsRemoved(const std::vector& hosts_removed); // ThreadLocal::ThreadLocalObject diff --git a/source/common/redis/proxy_filter.cc b/source/common/redis/proxy_filter.cc index dd830694344f..f68f7522ac6c 100644 --- a/source/common/redis/proxy_filter.cc +++ b/source/common/redis/proxy_filter.cc @@ -6,10 +6,8 @@ namespace Redis { ProxyFilterConfig::ProxyFilterConfig(const Json::Object& config, Upstream::ClusterManager& cm) - : cluster_name_{config.getString("cluster_name")} { - - config.validateSchema(Json::Schema::REDIS_PROXY_NETWORK_FILTER_SCHEMA); - + : Json::Validator(config, Json::Schema::REDIS_PROXY_NETWORK_FILTER_SCHEMA), + cluster_name_{config.getString("cluster_name")} { if (!cm.get(cluster_name_)) { throw EnvoyException( fmt::format("redis filter config: unknown cluster name '{}'", cluster_name_)); @@ -21,10 +19,11 @@ ProxyFilter::~ProxyFilter() { ASSERT(pending_requests_.empty()); } void ProxyFilter::onRespValue(RespValuePtr&& value) { pending_requests_.emplace_back(*this); PendingRequest& request = pending_requests_.back(); - request.request_handle_ = conn_pool_.makeRequest("", *value, request); - if (!request.request_handle_) { - respondWithFailure("no healthy upstream"); - pending_requests_.pop_back(); + CommandSplitter::SplitRequestPtr split = splitter_.makeRequest(*value, request); + if (split) { + // The splitter can immediately respond and destroy the pending request. Only store the handle + // if the request is still alive. + request.request_handle_ = std::move(split); } } @@ -55,30 +54,19 @@ void ProxyFilter::onResponse(PendingRequest& request, RespValuePtr&& value) { } } -void ProxyFilter::onFailure(PendingRequest& request) { - RespValuePtr error(new RespValue()); - error->type(RespType::Error); - error->asString() = "upstream connection error"; - onResponse(request, std::move(error)); -} - Network::FilterStatus ProxyFilter::onData(Buffer::Instance& data) { try { decoder_->decode(data); return Network::FilterStatus::Continue; } catch (ProtocolError&) { - respondWithFailure("downstream protocol error"); + RespValue error; + error.type(RespType::Error); + error.asString() = "downstream protocol error"; + encoder_->encode(error, encoder_buffer_); + callbacks_->connection().write(encoder_buffer_); callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); return Network::FilterStatus::StopIteration; } } -void ProxyFilter::respondWithFailure(const std::string& message) { - RespValue error; - error.type(RespType::Error); - error.asString() = message; - encoder_->encode(error, encoder_buffer_); - callbacks_->connection().write(encoder_buffer_); -} - } // Redis diff --git a/source/common/redis/proxy_filter.h b/source/common/redis/proxy_filter.h index 1048d0d12277..4c48b60d6c4a 100644 --- a/source/common/redis/proxy_filter.h +++ b/source/common/redis/proxy_filter.h @@ -2,20 +2,21 @@ #include "envoy/network/filter.h" #include "envoy/redis/codec.h" -#include "envoy/redis/conn_pool.h" +#include "envoy/redis/command_splitter.h" +#include "envoy/upstream/cluster_manager.h" #include "common/buffer/buffer_impl.h" #include "common/json/json_loader.h" +#include "common/json/json_validator.h" namespace Redis { // TODO(mattklein123): Stats -// TODO(mattklein123): Actual multiplexing, command verification, and splitting /** * Configuration for the redis proxy filter. */ -class ProxyFilterConfig { +class ProxyFilterConfig : Json::Validator { public: ProxyFilterConfig(const Json::Object& config, Upstream::ClusterManager& cm); @@ -33,8 +34,8 @@ class ProxyFilter : public Network::ReadFilter, public DecoderCallbacks, public Network::ConnectionCallbacks { public: - ProxyFilter(DecoderFactory& factory, EncoderPtr&& encoder, ConnPool::Instance& conn_pool) - : decoder_(factory.create(*this)), encoder_(std::move(encoder)), conn_pool_(conn_pool) {} + ProxyFilter(DecoderFactory& factory, EncoderPtr&& encoder, CommandSplitter::Instance& splitter) + : decoder_(factory.create(*this)), encoder_(std::move(encoder)), splitter_(splitter) {} ~ProxyFilter(); @@ -53,25 +54,22 @@ class ProxyFilter : public Network::ReadFilter, void onRespValue(RespValuePtr&& value) override; private: - struct PendingRequest : public ConnPool::ActiveRequestCallbacks { + struct PendingRequest : public CommandSplitter::SplitCallbacks { PendingRequest(ProxyFilter& parent) : parent_(parent) {} - // Redis::ConnPool::ActiveRequestCallbacks + // Redis::CommandSplitter::SplitCallbacks void onResponse(RespValuePtr&& value) override { parent_.onResponse(*this, std::move(value)); } - void onFailure() override { parent_.onFailure(*this); } ProxyFilter& parent_; RespValuePtr pending_response_; - ConnPool::ActiveRequest* request_handle_; + CommandSplitter::SplitRequestPtr request_handle_; }; void onResponse(PendingRequest& request, RespValuePtr&& value); - void onFailure(PendingRequest& request); - void respondWithFailure(const std::string& message); DecoderPtr decoder_; EncoderPtr encoder_; - ConnPool::Instance& conn_pool_; + CommandSplitter::Instance& splitter_; Buffer::OwnedImpl encoder_buffer_; Network::ReadFilterCallbacks* callbacks_{}; std::list pending_requests_; diff --git a/source/common/router/config_utility.h b/source/common/router/config_utility.h index e60f5a6c2afa..54ddfd9ab77f 100644 --- a/source/common/router/config_utility.h +++ b/source/common/router/config_utility.h @@ -15,12 +15,12 @@ namespace Router { */ class ConfigUtility { public: - struct HeaderData : Json::JsonValidator { + struct HeaderData : Json::Validator { // An empty header value allows for matching to be only based on header presence. // Regex is an opt-in. Unless explicitly mentioned, the header values will be used for // exact string matching. HeaderData(const Json::Object& config) - : Json::JsonValidator(config, Json::Schema::HEADER_DATA_CONFIGURATION_SCHEMA), + : Json::Validator(config, Json::Schema::HEADER_DATA_CONFIGURATION_SCHEMA), name_(config.getString("name")), value_(config.getString("value", EMPTY_STRING)), regex_pattern_(value_, std::regex::optimize), is_regex_(config.getBoolean("regex", false)) {} diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index 3c1020c0587c..c10b49bc2c0f 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -69,7 +69,7 @@ void HeaderValueMatchAction::populateDescriptor(const Router::RouteEntry&, } RateLimitPolicyEntryImpl::RateLimitPolicyEntryImpl(const Json::Object& config) - : Json::JsonValidator(config, Json::Schema::HTTP_RATE_LIMITS_CONFIGURATION_SCHEMA), + : Json::Validator(config, Json::Schema::HTTP_RATE_LIMITS_CONFIGURATION_SCHEMA), disable_key_(config.getString("disable_key", "")), stage_(static_cast(config.getInteger("stage", 0))) { for (const Json::ObjectPtr& action : config.getObjectArray("actions")) { diff --git a/source/common/router/router_ratelimit.h b/source/common/router/router_ratelimit.h index ec3f0cb26e1f..87a700ee0a4c 100644 --- a/source/common/router/router_ratelimit.h +++ b/source/common/router/router_ratelimit.h @@ -98,7 +98,7 @@ class HeaderValueMatchAction : public RateLimitAction { /* * Implementation of RateLimitPolicyEntry that holds the action for the configuration. */ -class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry, Json::JsonValidator { +class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry, Json::Validator { public: RateLimitPolicyEntryImpl(const Json::Object& config); diff --git a/source/server/config/network/redis_proxy.cc b/source/server/config/network/redis_proxy.cc index 65579019006e..10930da400e6 100644 --- a/source/server/config/network/redis_proxy.cc +++ b/source/server/config/network/redis_proxy.cc @@ -1,6 +1,7 @@ #include "redis_proxy.h" #include "common/redis/codec_impl.h" +#include "common/redis/command_splitter_impl.h" #include "common/redis/conn_pool_impl.h" #include "common/redis/proxy_filter.h" @@ -15,13 +16,15 @@ NetworkFilterFactoryCb RedisProxyFilterConfigFactory::tryCreateFilterFactory( } Redis::ProxyFilterConfig filter_config(config, server.clusterManager()); - std::shared_ptr conn_pool(new Redis::ConnPool::InstanceImpl( + Redis::ConnPool::InstancePtr conn_pool(new Redis::ConnPool::InstanceImpl( filter_config.clusterName(), server.clusterManager(), Redis::ConnPool::ClientFactoryImpl::instance_, server.threadLocal())); - return [conn_pool](Network::FilterManager& filter_manager) -> void { + std::shared_ptr splitter( + new Redis::CommandSplitter::InstanceImpl(std::move(conn_pool))); + return [splitter](Network::FilterManager& filter_manager) -> void { Redis::DecoderFactoryImpl factory; filter_manager.addReadFilter(Network::ReadFilterSharedPtr{ - new Redis::ProxyFilter(factory, Redis::EncoderPtr{new Redis::EncoderImpl()}, *conn_pool)}); + new Redis::ProxyFilter(factory, Redis::EncoderPtr{new Redis::EncoderImpl()}, *splitter)}); }; } diff --git a/source/server/server.cc b/source/server/server.cc index cc9cd2a9deb6..f25e17909518 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -361,8 +361,6 @@ void InstanceImpl::run() { Runtime::Loader& InstanceImpl::runtime() { return *runtime_loader_; } -InstanceImpl::~InstanceImpl() {} - void InstanceImpl::shutdown() { log().warn("shutdown invoked. sending SIGTERM to self"); kill(getpid(), SIGTERM); diff --git a/source/server/server.h b/source/server/server.h index d2d49c1560ac..d9972fdbd118 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -94,7 +94,6 @@ class InstanceImpl : Logger::Loggable, public Instance { InstanceImpl(Options& options, TestHooks& hooks, HotRestart& restarter, Stats::StoreRoot& store, Thread::BasicLockable& access_log_lock, ComponentFactory& component_factory, const LocalInfo::LocalInfo& local_info); - ~InstanceImpl(); void run(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 966f39067f29..7484d813a5a4 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -82,6 +82,7 @@ add_executable(envoy-test common/network/utility_test.cc common/ratelimit/ratelimit_impl_test.cc common/redis/codec_impl_test.cc + common/redis/command_splitter_impl_test.cc common/redis/conn_pool_impl_test.cc common/redis/proxy_filter_test.cc common/router/config_impl_test.cc diff --git a/test/common/redis/codec_impl_test.cc b/test/common/redis/codec_impl_test.cc index 939c8700aae0..0e835aeaac60 100644 --- a/test/common/redis/codec_impl_test.cc +++ b/test/common/redis/codec_impl_test.cc @@ -25,6 +25,7 @@ class RedisEncoderDecoderImplTest : public testing::Test, public DecoderCallback TEST_F(RedisEncoderDecoderImplTest, Null) { RespValue value; + EXPECT_EQ("null", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("$-1\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -36,6 +37,7 @@ TEST_F(RedisEncoderDecoderImplTest, Error) { RespValue value; value.type(RespType::Error); value.asString() = "error"; + EXPECT_EQ("\"error\"", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("-error\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -47,6 +49,7 @@ TEST_F(RedisEncoderDecoderImplTest, SimpleString) { RespValue value; value.type(RespType::SimpleString); value.asString() = "simple string"; + EXPECT_EQ("\"simple string\"", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("+simple string\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -58,6 +61,7 @@ TEST_F(RedisEncoderDecoderImplTest, Integer) { RespValue value; value.type(RespType::Integer); value.asInteger() = std::numeric_limits::max(); + EXPECT_EQ("9223372036854775807", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ(":9223372036854775807\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -79,6 +83,7 @@ TEST_F(RedisEncoderDecoderImplTest, NegativeInteger) { TEST_F(RedisEncoderDecoderImplTest, EmptyArray) { RespValue value; value.type(RespType::Array); + EXPECT_EQ("[]", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("*0\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -96,6 +101,7 @@ TEST_F(RedisEncoderDecoderImplTest, Array) { RespValue value; value.type(RespType::Array); value.asArray().swap(values); + EXPECT_EQ("[\"hello\", -5]", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("*2\r\n$5\r\nhello\r\n:-5\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); diff --git a/test/common/redis/command_splitter_impl_test.cc b/test/common/redis/command_splitter_impl_test.cc new file mode 100644 index 000000000000..74841023921e --- /dev/null +++ b/test/common/redis/command_splitter_impl_test.cc @@ -0,0 +1,333 @@ +#include "common/redis/command_splitter_impl.h" + +#include "test/mocks/common.h" +#include "test/mocks/redis/mocks.h" + +using testing::_; +using testing::ByRef; +using testing::DoAll; +using testing::Eq; +using testing::InSequence; +using testing::Ref; +using testing::Return; +using testing::WithArg; + +namespace Redis { +namespace CommandSplitter { + +class RedisCommandSplitterImplTest : public testing::Test { +public: + void makeBulkStringArray(RespValue& value, const std::vector& strings) { + std::vector values(strings.size()); + for (uint64_t i = 0; i < strings.size(); i++) { + values[i].type(RespType::BulkString); + values[i].asString() = strings[i]; + } + + value.type(RespType::Array); + value.asArray().swap(values); + } + + ConnPool::MockInstance* conn_pool_{new ConnPool::MockInstance()}; + InstanceImpl splitter_{ConnPool::InstancePtr{conn_pool_}}; + MockSplitCallbacks callbacks_; + SplitRequestPtr handle_; +}; + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestNotArray) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayTooSmall) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"incr"}); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayNotStrings) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"incr", ""}); + request.asArray()[1].type(RespType::Null); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, UnsupportedCommand) { + RespValue response; + response.type(RespType::Error); + response.asString() = "unsupported command 'newcommand'"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"newcommand", "hello"}); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +class RedisAllParamsToOneServerCommandHandlerTest : public RedisCommandSplitterImplTest { +public: + void makeRequest(const std::string& hash_key, const RespValue& request) { + EXPECT_CALL(*conn_pool_, makeRequest(hash_key, Ref(request), _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&pool_callbacks_)), Return(&pool_request_))); + handle_ = splitter_.makeRequest(request, callbacks_); + } + + void fail() { + RespValue response; + response.type(RespType::Error); + response.asString() = "upstream failure"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + pool_callbacks_->onFailure(); + } + + void respond() { + RespValuePtr response1(new RespValue()); + RespValue* response1_ptr = response1.get(); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(response1_ptr))); + pool_callbacks_->onResponse(std::move(response1)); + } + + ConnPool::PoolCallbacks* pool_callbacks_; + ConnPool::MockPoolRequest pool_request_; +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrSuccess) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + respond(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrFail) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + fail(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrCancel) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + EXPECT_CALL(pool_request_, cancel()); + handle_->cancel(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrNoUpstream) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + EXPECT_CALL(*conn_pool_, makeRequest("hello", Ref(request), _)).WillOnce(Return(nullptr)); + RespValue response; + response.type(RespType::Error); + response.asString() = "no upstream host"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(request, callbacks_); + EXPECT_EQ(nullptr, handle_); +}; + +class RedisMGETCommandHandlerTest : public RedisCommandSplitterImplTest { +public: + void setup(uint32_t num_gets, const std::list& null_handle_indexes) { + std::vector request_strings = {"mget"}; + for (uint32_t i = 0; i < num_gets; i++) { + request_strings.push_back(std::to_string(i)); + } + + RespValue request; + makeBulkStringArray(request, request_strings); + + std::vector tmp_expected_requests(num_gets); + expected_requests_.swap(tmp_expected_requests); + pool_callbacks_.resize(num_gets); + std::vector tmp_pool_requests(num_gets); + pool_requests_.swap(tmp_pool_requests); + for (uint32_t i = 0; i < num_gets; i++) { + makeBulkStringArray(expected_requests_[i], {"get", std::to_string(i)}); + ConnPool::PoolRequest* request_to_use = nullptr; + if (std::find(null_handle_indexes.begin(), null_handle_indexes.end(), i) == + null_handle_indexes.end()) { + request_to_use = &pool_requests_[i]; + } + EXPECT_CALL(*conn_pool_, makeRequest(std::to_string(i), Eq(ByRef(expected_requests_[i])), _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&pool_callbacks_[i])), Return(request_to_use))); + } + + handle_ = splitter_.makeRequest(request, callbacks_); + } + + std::vector expected_requests_; + std::vector pool_callbacks_; + std::vector pool_requests_; +}; + +TEST_F(RedisMGETCommandHandlerTest, Normal) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + elements[1].type(RespType::BulkString); + elements[1].asString() = "5"; + expected_response.asArray().swap(elements); + + RespValuePtr response2(new RespValue()); + response2->type(RespType::BulkString); + response2->asString() = "5"; + pool_callbacks_[1]->onResponse(std::move(response2)); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, NormalWithNull) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + expected_response.asArray().swap(elements); + + RespValuePtr response2(new RespValue()); + pool_callbacks_[1]->onResponse(std::move(response2)); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, NoUpstreamHostForAll) { + // No InSequence to avoid making setup() more complicated. + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "no upstream host"; + elements[1].type(RespType::Error); + elements[1].asString() = "no upstream host"; + expected_response.asArray().swap(elements); + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + setup(2, {0, 1}); + EXPECT_EQ(nullptr, handle_); +}; + +TEST_F(RedisMGETCommandHandlerTest, NoUpstreamHostForOne) { + InSequence s; + + setup(2, {0}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "no upstream host"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[1]->onFailure(); +}; + +TEST_F(RedisMGETCommandHandlerTest, Failure) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + pool_callbacks_[1]->onFailure(); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, InvalidUpstreamResponse) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "upstream protocol error"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + pool_callbacks_[1]->onFailure(); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::Integer); + response1->asInteger() = 5; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, Cancel) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + EXPECT_CALL(pool_requests_[0], cancel()); + EXPECT_CALL(pool_requests_[1], cancel()); + handle_->cancel(); +}; + +} // CommandSplitter +} // Redis diff --git a/test/common/redis/conn_pool_impl_test.cc b/test/common/redis/conn_pool_impl_test.cc index 9a0351b996f5..ad16d12a9522 100644 --- a/test/common/redis/conn_pool_impl_test.cc +++ b/test/common/redis/conn_pool_impl_test.cc @@ -53,15 +53,15 @@ TEST_F(RedisClientImplTest, Basic) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); RespValue request2; - MockActiveRequestCallbacks callbacks2; + MockPoolCallbacks callbacks2; EXPECT_CALL(*encoder_, encode(Ref(request2), _)); - ActiveRequest* handle2 = client_->makeRequest(request2, callbacks2); + PoolRequest* handle2 = client_->makeRequest(request2, callbacks2); EXPECT_NE(nullptr, handle2); Buffer::OwnedImpl fake_data; @@ -86,15 +86,15 @@ TEST_F(RedisClientImplTest, Cancel) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); RespValue request2; - MockActiveRequestCallbacks callbacks2; + MockPoolCallbacks callbacks2; EXPECT_CALL(*encoder_, encode(Ref(request2), _)); - ActiveRequest* handle2 = client_->makeRequest(request2, callbacks2); + PoolRequest* handle2 = client_->makeRequest(request2, callbacks2); EXPECT_NE(nullptr, handle2); handle1->cancel(); @@ -124,9 +124,9 @@ TEST_F(RedisClientImplTest, FailAll) { client_->addConnectionCallbacks(connection_callbacks); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); EXPECT_CALL(connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)); @@ -138,9 +138,9 @@ TEST_F(RedisClientImplTest, ProtocolError) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); Buffer::OwnedImpl fake_data; @@ -181,8 +181,8 @@ TEST_F(RedisConnPoolImplTest, Basic) { InSequence s; RespValue value; - MockActiveRequest active_request; - MockActiveRequestCallbacks callbacks; + MockPoolRequest active_request; + MockPoolCallbacks callbacks; MockClient* client = new NiceMock(); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) @@ -193,7 +193,7 @@ TEST_F(RedisConnPoolImplTest, Basic) { })); EXPECT_CALL(*this, create_(_, _)).WillOnce(Return(client)); EXPECT_CALL(*client, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request)); - ActiveRequest* request = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(&active_request, request); EXPECT_CALL(*client, close()); @@ -202,7 +202,7 @@ TEST_F(RedisConnPoolImplTest, Basic) { TEST_F(RedisConnPoolImplTest, HostRemove) { InSequence s; - MockActiveRequestCallbacks callbacks; + MockPoolCallbacks callbacks; RespValue value; std::shared_ptr host1(new Upstream::MockHost()); @@ -213,17 +213,17 @@ TEST_F(RedisConnPoolImplTest, HostRemove) { EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(host1)); EXPECT_CALL(*this, create_(Eq(host1), _)).WillOnce(Return(client1)); - MockActiveRequest active_request1; + MockPoolRequest active_request1; EXPECT_CALL(*client1, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request1)); - ActiveRequest* request1 = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request1 = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(&active_request1, request1); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(host2)); EXPECT_CALL(*this, create_(Eq(host2), _)).WillOnce(Return(client2)); - MockActiveRequest active_request2; + MockPoolRequest active_request2; EXPECT_CALL(*client2, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request2)); - ActiveRequest* request2 = conn_pool_.makeRequest("bar", value, callbacks); + PoolRequest* request2 = conn_pool_.makeRequest("bar", value, callbacks); EXPECT_EQ(&active_request2, request2); EXPECT_CALL(*client2, close()); @@ -237,9 +237,9 @@ TEST_F(RedisConnPoolImplTest, NoHost) { InSequence s; RespValue value; - MockActiveRequestCallbacks callbacks; + MockPoolCallbacks callbacks; EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(nullptr)); - ActiveRequest* request = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(nullptr, request); tls_.shutdownThread(); @@ -249,8 +249,8 @@ TEST_F(RedisConnPoolImplTest, RemoteClose) { InSequence s; RespValue value; - MockActiveRequest active_request; - MockActiveRequestCallbacks callbacks; + MockPoolRequest active_request; + MockPoolCallbacks callbacks; MockClient* client = new NiceMock(); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)); diff --git a/test/common/redis/proxy_filter_test.cc b/test/common/redis/proxy_filter_test.cc index c59c69f4375d..875caae6550f 100644 --- a/test/common/redis/proxy_filter_test.cc +++ b/test/common/redis/proxy_filter_test.cc @@ -73,8 +73,8 @@ class RedisProxyFilterTest : public testing::Test, public DecoderFactory { MockEncoder* encoder_{new MockEncoder()}; MockDecoder* decoder_{new MockDecoder()}; DecoderCallbacks* decoder_callbacks_{}; - ConnPool::MockInstance conn_pool_; - ProxyFilter filter_{*this, EncoderPtr{encoder_}, conn_pool_}; + CommandSplitter::MockInstance splitter_; + ProxyFilter filter_{*this, EncoderPtr{encoder_}, splitter_}; NiceMock filter_callbacks_; }; @@ -82,22 +82,22 @@ TEST_F(RedisProxyFilterTest, OutOfOrderResponse) { InSequence s; Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; - ConnPool::MockActiveRequest request_handle2; - ConnPool::ActiveRequestCallbacks* request_callbacks2; + CommandSplitter::MockSplitRequest* request_handle1 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks1; + CommandSplitter::MockSplitRequest* request_handle2 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks2; EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks1)), Return(request_handle1))); decoder_callbacks_->onRespValue(std::move(request1)); RespValuePtr request2(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request2), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request2), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks2)), Return(&request_handle2))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks2)), Return(request_handle2))); decoder_callbacks_->onRespValue(std::move(request2)); })); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); @@ -115,53 +115,27 @@ TEST_F(RedisProxyFilterTest, OutOfOrderResponse) { filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } -TEST_F(RedisProxyFilterTest, UpstreamFailure) { - InSequence s; - - Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; - EXPECT_CALL(*decoder_, decode(Ref(fake_data))) - .WillOnce(Invoke([&](Buffer::Instance&) -> void { - RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) - .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); - decoder_callbacks_->onRespValue(std::move(request1)); - })); - EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - - RespValue error; - error.type(RespType::Error); - error.asString() = "upstream connection error"; - EXPECT_CALL(*encoder_, encode(Eq(ByRef(error)), _)); - EXPECT_CALL(filter_callbacks_.connection_, write(_)); - request_callbacks1->onFailure(); - - filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::LocalClose); -} - TEST_F(RedisProxyFilterTest, DownstreamDisconnectWithActive) { InSequence s; Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; + CommandSplitter::MockSplitRequest* request_handle1 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks1; EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks1)), Return(request_handle1))); decoder_callbacks_->onRespValue(std::move(request1)); })); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - EXPECT_CALL(request_handle1, cancel()); + EXPECT_CALL(*request_handle1, cancel()); filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } -TEST_F(RedisProxyFilterTest, NoClient) { +TEST_F(RedisProxyFilterTest, ImmediateResponse) { InSequence s; Buffer::OwnedImpl fake_data; @@ -169,15 +143,19 @@ TEST_F(RedisProxyFilterTest, NoClient) { EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { decoder_callbacks_->onRespValue(std::move(request1)); })); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)).WillOnce(Return(nullptr)); + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) + .WillOnce(Invoke([&](const RespValue&, CommandSplitter::SplitCallbacks& callbacks) + -> CommandSplitter::SplitRequest* { + RespValuePtr error(new RespValue()); + error->type(RespType::Error); + error->asString() = "no healthy upstream"; + EXPECT_CALL(*encoder_, encode(Eq(ByRef(*error)), _)); + EXPECT_CALL(filter_callbacks_.connection_, write(_)); + callbacks.onResponse(std::move(error)); + return nullptr; + })); - RespValue error; - error.type(RespType::Error); - error.asString() = "no healthy upstream"; - EXPECT_CALL(*encoder_, encode(Eq(ByRef(error)), _)); - EXPECT_CALL(filter_callbacks_.connection_, write(_)); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } diff --git a/test/mocks/common.h b/test/mocks/common.h index 27a59c92dcde..8076be729085 100644 --- a/test/mocks/common.h +++ b/test/mocks/common.h @@ -12,7 +12,10 @@ ACTION_P(SaveArgAddress, target) { *target = &arg0; } /** * Matcher that matches on whether the pointee of both lhs and rhs are equal. */ -MATCHER_P(PointeesEq, rhs, "") { return *arg == *rhs; } +MATCHER_P(PointeesEq, rhs, "") { + *result_listener << testing::PrintToString(*arg) + " != " + testing::PrintToString(*rhs); + return *arg == *rhs; +} /** * Simple mock that just lets us make sure a method gets called or not called form a lambda. diff --git a/test/mocks/redis/mocks.cc b/test/mocks/redis/mocks.cc index 6ffb7072be7d..582dd5013ab1 100644 --- a/test/mocks/redis/mocks.cc +++ b/test/mocks/redis/mocks.cc @@ -65,14 +65,27 @@ MockClient::MockClient() { MockClient::~MockClient() {} -MockActiveRequest::MockActiveRequest() {} -MockActiveRequest::~MockActiveRequest() {} +MockPoolRequest::MockPoolRequest() {} +MockPoolRequest::~MockPoolRequest() {} -MockActiveRequestCallbacks::MockActiveRequestCallbacks() {} -MockActiveRequestCallbacks::~MockActiveRequestCallbacks() {} +MockPoolCallbacks::MockPoolCallbacks() {} +MockPoolCallbacks::~MockPoolCallbacks() {} MockInstance::MockInstance() {} MockInstance::~MockInstance() {} } // ConnPool + +namespace CommandSplitter { + +MockSplitRequest::MockSplitRequest() {} +MockSplitRequest::~MockSplitRequest() {} + +MockSplitCallbacks::MockSplitCallbacks() {} +MockSplitCallbacks::~MockSplitCallbacks() {} + +MockInstance::MockInstance() {} +MockInstance::~MockInstance() {} + +} // CommandSplitter } // Redis diff --git a/test/mocks/redis/mocks.h b/test/mocks/redis/mocks.h index bf0cd3e5559f..e0da90d1c93d 100644 --- a/test/mocks/redis/mocks.h +++ b/test/mocks/redis/mocks.h @@ -1,5 +1,6 @@ #pragma once +#include "envoy/redis/command_splitter.h" #include "envoy/redis/conn_pool.h" #include "common/redis/codec_impl.h" @@ -42,24 +43,23 @@ class MockClient : public Client { MOCK_METHOD1(addConnectionCallbacks, void(Network::ConnectionCallbacks& callbacks)); MOCK_METHOD0(close, void()); - MOCK_METHOD2(makeRequest, - ActiveRequest*(const RespValue& request, ActiveRequestCallbacks& callbacks)); + MOCK_METHOD2(makeRequest, PoolRequest*(const RespValue& request, PoolCallbacks& callbacks)); std::list callbacks_; }; -class MockActiveRequest : public ActiveRequest { +class MockPoolRequest : public PoolRequest { public: - MockActiveRequest(); - ~MockActiveRequest(); + MockPoolRequest(); + ~MockPoolRequest(); MOCK_METHOD0(cancel, void()); }; -class MockActiveRequestCallbacks : public ActiveRequestCallbacks { +class MockPoolCallbacks : public PoolCallbacks { public: - MockActiveRequestCallbacks(); - ~MockActiveRequestCallbacks(); + MockPoolCallbacks(); + ~MockPoolCallbacks(); void onResponse(RespValuePtr&& value) override { onResponse_(value); } @@ -72,9 +72,43 @@ class MockInstance : public Instance { MockInstance(); ~MockInstance(); - MOCK_METHOD3(makeRequest, ActiveRequest*(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks)); + MOCK_METHOD3(makeRequest, PoolRequest*(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks)); }; } // ConnPool + +namespace CommandSplitter { + +class MockSplitRequest : public SplitRequest { +public: + MockSplitRequest(); + ~MockSplitRequest(); + + MOCK_METHOD0(cancel, void()); +}; + +class MockSplitCallbacks : public SplitCallbacks { +public: + MockSplitCallbacks(); + ~MockSplitCallbacks(); + + void onResponse(RespValuePtr&& value) override { onResponse_(value); } + + MOCK_METHOD1(onResponse_, void(RespValuePtr& value)); +}; + +class MockInstance : public Instance { +public: + MockInstance(); + ~MockInstance(); + + SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) override { + return SplitRequestPtr{makeRequest_(request, callbacks)}; + } + + MOCK_METHOD2(makeRequest_, SplitRequest*(const RespValue& request, SplitCallbacks& callbacks)); +}; + +} // CommandSplitter } // Redis diff --git a/test/test_common/printers.cc b/test/test_common/printers.cc index b0ec10c69832..ef0308efbe19 100644 --- a/test/test_common/printers.cc +++ b/test/test_common/printers.cc @@ -1,5 +1,7 @@ #include "printers.h" +#include "envoy/redis/codec.h" + #include "common/buffer/buffer_impl.h" #include "common/http/header_map_impl.h" @@ -18,7 +20,7 @@ void PrintTo(const HeaderMapPtr& headers, std::ostream* os) { void PrintTo(const HeaderMap& headers, std::ostream* os) { PrintTo(*dynamic_cast(&headers), os); } -} +} // Http namespace Buffer { void PrintTo(const Instance& buffer, std::ostream* os) { @@ -28,4 +30,10 @@ void PrintTo(const Instance& buffer, std::ostream* os) { void PrintTo(const Buffer::OwnedImpl& buffer, std::ostream* os) { PrintTo(dynamic_cast(buffer), os); } -} +} // Buffer + +namespace Redis { +void PrintTo(const RespValue& value, std::ostream* os) { *os << value.toString(); } + +void PrintTo(const RespValuePtr& value, std::ostream* os) { *os << value->toString(); } +} // Redis diff --git a/test/test_common/printers.h b/test/test_common/printers.h index a56db2b1e4e9..ba62cd170d34 100644 --- a/test/test_common/printers.h +++ b/test/test_common/printers.h @@ -14,7 +14,7 @@ class HeaderMap; typedef std::unique_ptr HeaderMapPtr; void PrintTo(const HeaderMap& headers, std::ostream* os); void PrintTo(const HeaderMapPtr& headers, std::ostream* os); -} +} // Http namespace Buffer { /** @@ -28,4 +28,14 @@ void PrintTo(const Instance& buffer, std::ostream* os); */ class OwnedImpl; void PrintTo(const OwnedImpl& buffer, std::ostream* os); -} +} // Buffer + +namespace Redis { +/** + * Pretty print const RespValue& value + */ +class RespValue; +typedef std::unique_ptr RespValuePtr; +void PrintTo(const RespValue& value, std::ostream* os); +void PrintTo(const RespValuePtr& value, std::ostream* os); +} // Redis