diff --git a/docs/build-from-source.md b/docs/build-from-source.md index bde2d993a6e8..5dbf590397ef 100644 --- a/docs/build-from-source.md +++ b/docs/build-from-source.md @@ -17,7 +17,7 @@ git clone --recursive https://github.com/dragonflydb/dragonfly && cd dragonfly ```bash # Install dependencies sudo apt install ninja-build libunwind-dev libboost-fiber-dev libssl-dev \ - autoconf-archive libtool cmake g++ + autoconf-archive libtool cmake g++ libzstd-dev ``` ## Step 3 diff --git a/src/redis/redis_aux.c b/src/redis/redis_aux.c index 680961099e1b..db3c922990ed 100644 --- a/src/redis/redis_aux.c +++ b/src/redis/redis_aux.c @@ -29,11 +29,9 @@ void InitRedisTables() { server.hash_max_listpack_entries = 512; server.hash_max_listpack_value = 32; // decreased from redis default 64. - server.rdb_compression = 1; - server.stream_node_max_bytes = 4096; server.stream_node_max_entries = 100; - } +} // These functions are moved here from server.c int htNeedsResize(dict* dict) { @@ -73,7 +71,7 @@ int dictPtrKeyCompare(dict* privdata, const void* key1, const void* key2) { return key1 == key2; } -int dictSdsKeyCompare(dict *d, const void* key1, const void* key2) { +int dictSdsKeyCompare(dict* d, const void* key1, const void* key2) { int l1, l2; DICT_NOTUSED(d); @@ -84,7 +82,7 @@ int dictSdsKeyCompare(dict *d, const void* key1, const void* key2) { return memcmp(key1, key2, l1) == 0; } -void dictSdsDestructor(dict *d, void* val) { +void dictSdsDestructor(dict* d, void* val) { DICT_NOTUSED(d); sdsfree(val); @@ -100,29 +98,28 @@ size_t sdsZmallocSize(sds s) { /* Toggle the 64 bit unsigned integer pointed by *p from little endian to * big endian */ -void memrev64(void *p) { - unsigned char *x = p, t; - - t = x[0]; - x[0] = x[7]; - x[7] = t; - t = x[1]; - x[1] = x[6]; - x[6] = t; - t = x[2]; - x[2] = x[5]; - x[5] = t; - t = x[3]; - x[3] = x[4]; - x[4] = t; +void memrev64(void* p) { + unsigned char *x = p, t; + + t = x[0]; + x[0] = x[7]; + x[7] = t; + t = x[1]; + x[1] = x[6]; + x[6] = t; + t = x[2]; + x[2] = x[5]; + x[5] = t; + t = x[3]; + x[3] = x[4]; + x[4] = t; } uint64_t intrev64(uint64_t v) { - memrev64(&v); - return v; + memrev64(&v); + return v; } - /* Set dictionary type. Keys are SDS strings, values are not used. */ dictType setDictType = { dictSdsHash, /* hash function */ @@ -147,11 +144,11 @@ dictType zsetDictType = { /* Hash type hash table (note that small hashes are represented with listpacks) */ dictType hashDictType = { - dictSdsHash, /* hash function */ - NULL, /* key dup */ - NULL, /* val dup */ - dictSdsKeyCompare, /* key compare */ - dictSdsDestructor, /* key destructor */ - dictSdsDestructor, /* val destructor */ - NULL /* allow to expand */ + dictSdsHash, /* hash function */ + NULL, /* key dup */ + NULL, /* val dup */ + dictSdsKeyCompare, /* key compare */ + dictSdsDestructor, /* key destructor */ + dictSdsDestructor, /* val destructor */ + NULL /* allow to expand */ }; diff --git a/src/redis/redis_aux.h b/src/redis/redis_aux.h index 9d794f99485c..44a20327490c 100644 --- a/src/redis/redis_aux.h +++ b/src/redis/redis_aux.h @@ -4,35 +4,34 @@ #include "dict.h" #include "sds.h" -#define HASHTABLE_MIN_FILL 10 /* Minimal hash table fill 10% */ -#define HASHTABLE_MAX_LOAD_FACTOR 1.618 /* Maximum hash table load factor. */ +#define HASHTABLE_MIN_FILL 10 /* Minimal hash table fill 10% */ +#define HASHTABLE_MAX_LOAD_FACTOR 1.618 /* Maximum hash table load factor. */ /* Redis maxmemory strategies. 
Instead of using just incremental number * for this defines, we use a set of flags so that testing for certain * properties common to multiple policies is faster. */ -#define MAXMEMORY_FLAG_LRU (1<<0) -#define MAXMEMORY_FLAG_LFU (1<<1) -#define MAXMEMORY_FLAG_ALLKEYS (1<<2) -#define MAXMEMORY_FLAG_NO_SHARED_INTEGERS (MAXMEMORY_FLAG_LRU|MAXMEMORY_FLAG_LFU) +#define MAXMEMORY_FLAG_LRU (1 << 0) +#define MAXMEMORY_FLAG_LFU (1 << 1) +#define MAXMEMORY_FLAG_ALLKEYS (1 << 2) +#define MAXMEMORY_FLAG_NO_SHARED_INTEGERS (MAXMEMORY_FLAG_LRU | MAXMEMORY_FLAG_LFU) #define LFU_INIT_VAL 5 -#define MAXMEMORY_VOLATILE_LRU ((0<<8)|MAXMEMORY_FLAG_LRU) -#define MAXMEMORY_VOLATILE_LFU ((1<<8)|MAXMEMORY_FLAG_LFU) -#define MAXMEMORY_VOLATILE_TTL (2<<8) -#define MAXMEMORY_VOLATILE_RANDOM (3<<8) -#define MAXMEMORY_ALLKEYS_LRU ((4<<8)|MAXMEMORY_FLAG_LRU|MAXMEMORY_FLAG_ALLKEYS) -#define MAXMEMORY_ALLKEYS_LFU ((5<<8)|MAXMEMORY_FLAG_LFU|MAXMEMORY_FLAG_ALLKEYS) -#define MAXMEMORY_ALLKEYS_RANDOM ((6<<8)|MAXMEMORY_FLAG_ALLKEYS) -#define MAXMEMORY_NO_EVICTION (7<<8) - +#define MAXMEMORY_VOLATILE_LRU ((0 << 8) | MAXMEMORY_FLAG_LRU) +#define MAXMEMORY_VOLATILE_LFU ((1 << 8) | MAXMEMORY_FLAG_LFU) +#define MAXMEMORY_VOLATILE_TTL (2 << 8) +#define MAXMEMORY_VOLATILE_RANDOM (3 << 8) +#define MAXMEMORY_ALLKEYS_LRU ((4 << 8) | MAXMEMORY_FLAG_LRU | MAXMEMORY_FLAG_ALLKEYS) +#define MAXMEMORY_ALLKEYS_LFU ((5 << 8) | MAXMEMORY_FLAG_LFU | MAXMEMORY_FLAG_ALLKEYS) +#define MAXMEMORY_ALLKEYS_RANDOM ((6 << 8) | MAXMEMORY_FLAG_ALLKEYS) +#define MAXMEMORY_NO_EVICTION (7 << 8) #define CONFIG_RUN_ID_SIZE 40U #define EVPOOL_CACHED_SDS_SIZE 255 #define EVPOOL_SIZE 16 -int htNeedsResize(dict *dict); // moved from server.cc +int htNeedsResize(dict* dict); // moved from server.cc /* Hash table types */ extern dictType zsetDictType; @@ -52,39 +51,35 @@ extern dictType hashDictType; * Empty entries have the key pointer set to NULL. */ struct evictionPoolEntry { - unsigned long long idle; /* Object idle time (inverse frequency for LFU) */ - sds key; /* Key name. */ - sds cached; /* Cached SDS object for key name. */ - int dbid; /* Key DB number. */ + unsigned long long idle; /* Object idle time (inverse frequency for LFU) */ + sds key; /* Key name. */ + sds cached; /* Cached SDS object for key name. */ + int dbid; /* Key DB number. */ }; -uint64_t dictSdsHash(const void *key); -int dictSdsKeyCompare(dict *privdata, const void *key1, const void *key2); -void dictSdsDestructor(dict *privdata, void *val); -size_t sdsZmallocSize(sds s) ; +uint64_t dictSdsHash(const void* key); +int dictSdsKeyCompare(dict* privdata, const void* key1, const void* key2); +void dictSdsDestructor(dict* privdata, void* val); +size_t sdsZmallocSize(sds s); typedef struct ServerStub { - int rdb_compression; - - int lfu_decay_time; /* LFU counter decay factor. */ - /* should not be used. Use FLAGS_list_max_ziplist_size and FLAGS_list_compress_depth instead. */ - // int list_compress_depth; - // int list_max_ziplist_size; + int lfu_decay_time; /* LFU counter decay factor. */ + /* should not be used. Use FLAGS_list_max_ziplist_size and FLAGS_list_compress_depth instead. */ + // int list_compress_depth; + // int list_max_ziplist_size; - // unused - left so that object.c will compile. - int maxmemory_policy; /* Policy for key eviction */ + // unused - left so that object.c will compile. 
+ int maxmemory_policy; /* Policy for key eviction */ - unsigned long page_size; - size_t hash_max_listpack_entries, - hash_max_listpack_value; - size_t zset_max_listpack_entries; - size_t zset_max_listpack_value; + unsigned long page_size; + size_t hash_max_listpack_entries, hash_max_listpack_value; + size_t zset_max_listpack_entries; + size_t zset_max_listpack_value; - size_t stream_node_max_bytes; - long long stream_node_max_entries; + size_t stream_node_max_bytes; + long long stream_node_max_entries; } Server; - extern Server server; void InitRedisTables(); diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index bbeedda33b44..72592ca4bad7 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -21,7 +21,7 @@ add_library(dragonfly_lib channel_slice.cc command_registry.cc zset_family.cc version.cc bitops_family.cc container_utils.cc) cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib strings_lib html_lib - absl::random_random TRDP::jsoncons) + absl::random_random TRDP::jsoncons zstd) add_library(dfly_test_lib test_utils.cc) cxx_link(dfly_test_lib dragonfly_lib epoll_fiber_lib facade_test gtest_main_ext) diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index c385a814403c..ef4d7dd8821b 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -26,6 +26,7 @@ extern "C" { ABSL_FLAG(uint32_t, dbnum, 16, "Number of databases"); ABSL_FLAG(uint32_t, keys_output_limit, 8192, "Maximum number of keys output by keys command"); +ABSL_DECLARE_FLAG(int, compression_mode); namespace dfly { using namespace std; @@ -429,7 +430,8 @@ OpResult OpDump(const OpArgs& op_args, string_view key) { if (IsValid(it)) { DVLOG(1) << "Dump: key '" << key << "' successfully found, going to dump it"; std::unique_ptr<::io::StringSink> sink = std::make_unique<::io::StringSink>(); - RdbSerializer serializer(sink.get()); + int compression_mode = absl::GetFlag(FLAGS_compression_mode); + RdbSerializer serializer(sink.get(), compression_mode != 0); // According to Redis code we need to // 1. Save the value itself - without the key diff --git a/src/server/rdb_extensions.h b/src/server/rdb_extensions.h index 2a9243fd911c..b2e259ca9841 100644 --- a/src/server/rdb_extensions.h +++ b/src/server/rdb_extensions.h @@ -10,3 +10,5 @@ // to notify that it finished streaming static data and is ready // to switch to the stable state replication phase. const uint8_t RDB_OPCODE_FULLSYNC_END = 200; +const uint8_t RDB_OPCODE_COMPRESSED_BLOB_START = 201; +const uint8_t RDB_OPCODE_COMPRESSED_BLOB_END = 202; diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index 1f738187e100..6e5c4e0110ff 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -16,9 +16,9 @@ extern "C" { #include "redis/zmalloc.h" #include "redis/zset.h" } - #include #include +#include #include "base/endian.h" #include "base/flags.h" @@ -203,6 +203,59 @@ bool resizeStringSet(robj* set, size_t size, bool use_set2) { } // namespace +class ZstdDecompressImpl { + public: + ZstdDecompressImpl() : uncompressed_mem_buf_{16_KB} { + dctx_ = ZSTD_createDCtx(); + } + ~ZstdDecompressImpl() { + ZSTD_freeDCtx(dctx_); + } + + io::Result Decompress(std::string_view str); + + private: + ZSTD_DCtx* dctx_; + base::IoBuf uncompressed_mem_buf_; +}; + +io::Result ZstdDecompressImpl::Decompress(std::string_view str) { + // Prepare membuf memory to uncompressed string. 
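+  // ZSTD_getFrameContentSize() only parses the frame header: it yields
+  // ZSTD_CONTENTSIZE_UNKNOWN when the producer did not record the original size
+  // (single-shot ZSTD_compressCCtx always records it) and ZSTD_CONTENTSIZE_ERROR
+  // when the buffer is not a valid zstd frame. The extra byte reserved below is
+  // for the RDB_OPCODE_COMPRESSED_BLOB_END marker appended after decompression.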
+ unsigned long long const uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size()); + if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) { + LOG(ERROR) << "Zstd compression missing frame content size"; + return Unexpected(errc::invalid_encoding); + } + if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) { + LOG(ERROR) << "Invalid ZSTD compressed string"; + return Unexpected(errc::invalid_encoding); + } + uncompressed_mem_buf_.Reserve(uncomp_size + 1); + + // Uncompress string to membuf + IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < uncomp_size) { + return Unexpected(errc::out_of_memory); + } + size_t const d_size = + ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size()); + if (d_size == 0 || d_size != uncomp_size) { + LOG(ERROR) << "Invalid ZSTD compressed string"; + return Unexpected(errc::rdb_file_corrupted); + } + uncompressed_mem_buf_.CommitWrite(d_size); + + // Add opcode of compressed blob end to membuf. + dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < 1) { + return Unexpected(errc::out_of_memory); + } + dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; + uncompressed_mem_buf_.CommitWrite(1); + + return &uncompressed_mem_buf_; +} + class RdbLoaderBase::OpaqueObjLoader { public: OpaqueObjLoader(int rdb_type, PrimeValue* pv) : rdb_type_(rdb_type), pv_(pv) { @@ -243,7 +296,11 @@ class RdbLoaderBase::OpaqueObjLoader { PrimeValue* pv_; }; -RdbLoaderBase::RdbLoaderBase() : mem_buf_{16_KB} { +RdbLoaderBase::RdbLoaderBase() : origin_mem_buf_{16_KB} { + mem_buf_ = &origin_mem_buf_; +} + +RdbLoaderBase::~RdbLoaderBase() { } void RdbLoaderBase::OpaqueObjLoader::operator()(const base::PODArray& str) { @@ -832,11 +889,11 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { uint8_t* next = (uint8_t*)dest; size_t bytes_read; - size_t to_copy = std::min(mem_buf_.InputLen(), size); + size_t to_copy = std::min(mem_buf_->InputLen(), size); DVLOG(2) << "Copying " << to_copy << " bytes"; - ::memcpy(next, mem_buf_.InputBuffer().data(), to_copy); - mem_buf_.ConsumeInput(to_copy); + ::memcpy(next, mem_buf_->InputBuffer().data(), to_copy); + mem_buf_->ConsumeInput(to_copy); size -= to_copy; if (size == 0) return kOk; @@ -862,7 +919,7 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { return kOk; } - io::MutableBytes mb = mem_buf_.AppendBuffer(); + io::MutableBytes mb = mem_buf_->AppendBuffer(); // Must be because mem_buf_ is be empty. DCHECK_GT(mb.size(), size); @@ -879,9 +936,9 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { DCHECK_LE(bytes_read_, source_limit_); - mem_buf_.CommitWrite(bytes_read); - ::memcpy(next, mem_buf_.InputBuffer().data(), size); - mem_buf_.ConsumeInput(size); + mem_buf_->CommitWrite(bytes_read); + ::memcpy(next, mem_buf_->InputBuffer().data(), size); + mem_buf_->ConsumeInput(size); return kOk; } @@ -953,8 +1010,8 @@ auto RdbLoaderBase::FetchLzfStringObject() -> io::Result { return Unexpected(rdb::rdb_file_corrupted); } - if (mem_buf_.InputLen() >= clen) { - cbuf = mem_buf_.InputBuffer().data(); + if (mem_buf_->InputLen() >= clen) { + cbuf = mem_buf_->InputBuffer().data(); } else { compr_buf_.resize(clen); zerocopy_decompress = false; @@ -977,7 +1034,7 @@ auto RdbLoaderBase::FetchLzfStringObject() -> io::Result { // FetchBuf consumes the input but if we have not went through that path // we need to consume now. 
if (zerocopy_decompress) - mem_buf_.ConsumeInput(clen); + mem_buf_->ConsumeInput(clen); return res; } @@ -1013,7 +1070,7 @@ io::Result RdbLoaderBase::FetchBinaryDouble() { return make_unexpected(ec); uint8_t buf[8]; - mem_buf_.ReadAndConsume(8, buf); + mem_buf_->ReadAndConsume(8, buf); u.val = base::LE::LoadT(buf); return u.d; } @@ -1438,7 +1495,7 @@ template io::Result RdbLoaderBase::FetchInt() { return make_unexpected(ec); char buf[16]; - mem_buf_.ReadAndConsume(sizeof(T), buf); + mem_buf_->ReadAndConsume(sizeof(T), buf); return base::LE::LoadT>(buf); } @@ -1477,7 +1534,7 @@ error_code RdbLoader::Load(io::Source* src) { absl::Time start = absl::Now(); src_ = src; - IoBuf::Bytes bytes = mem_buf_.AppendBuffer(); + IoBuf::Bytes bytes = mem_buf_->AppendBuffer(); io::Result read_sz = src_->ReadAtLeast(bytes, 9); if (!read_sz) return read_sz.error(); @@ -1487,10 +1544,10 @@ error_code RdbLoader::Load(io::Source* src) { return RdbError(errc::wrong_signature); } - mem_buf_.CommitWrite(bytes_read_); + mem_buf_->CommitWrite(bytes_read_); { - auto cb = mem_buf_.InputBuffer(); + auto cb = mem_buf_->InputBuffer(); if (memcmp(cb.data(), "REDIS", 5) != 0) { return RdbError(errc::wrong_signature); @@ -1505,7 +1562,7 @@ error_code RdbLoader::Load(io::Source* src) { return RdbError(errc::bad_version); } - mem_buf_.ConsumeInput(9); + mem_buf_->ConsumeInput(9); } int type; @@ -1606,6 +1663,15 @@ error_code RdbLoader::Load(io::Source* src) { return RdbError(errc::feature_not_supported); } + if (type == RDB_OPCODE_COMPRESSED_BLOB_START) { + RETURN_ON_ERR(HandleCompressedBlob()); + continue; + } + if (type == RDB_OPCODE_COMPRESSED_BLOB_END) { + RETURN_ON_ERR(HandleCompressedBlobFinish()); + continue; + } + if (!rdbIsObjectType(type)) { return RdbError(errc::invalid_rdb_type); } @@ -1640,9 +1706,9 @@ error_code RdbLoader::Load(io::Source* src) { } error_code RdbLoaderBase::EnsureReadInternal(size_t min_sz) { - DCHECK_LT(mem_buf_.InputLen(), min_sz); + DCHECK_LT(mem_buf_->InputLen(), min_sz); - auto out_buf = mem_buf_.AppendBuffer(); + auto out_buf = mem_buf_->AppendBuffer(); CHECK_GT(out_buf.size(), min_sz); // If limit was applied we do not want to read more than needed @@ -1661,7 +1727,7 @@ error_code RdbLoaderBase::EnsureReadInternal(size_t min_sz) { bytes_read_ += *res; DCHECK_LE(bytes_read_, source_limit_); - mem_buf_.CommitWrite(*res); + mem_buf_->CommitWrite(*res); return kOk; } @@ -1677,9 +1743,9 @@ auto RdbLoaderBase::LoadLen(bool* is_encoded) -> io::Result { return make_unexpected(ec); uint64_t res = 0; - uint8_t first = mem_buf_.InputBuffer()[0]; + uint8_t first = mem_buf_->InputBuffer()[0]; int type = (first & 0xC0) >> 6; - mem_buf_.ConsumeInput(1); + mem_buf_->ConsumeInput(1); if (type == RDB_ENCVAL) { /* Read a 6 bit encoding type. */ if (is_encoded) @@ -1689,16 +1755,16 @@ auto RdbLoaderBase::LoadLen(bool* is_encoded) -> io::Result { /* Read a 6 bit len. */ res = first & 0x3F; } else if (type == RDB_14BITLEN) { - res = ((first & 0x3F) << 8) | mem_buf_.InputBuffer()[0]; - mem_buf_.ConsumeInput(1); + res = ((first & 0x3F) << 8) | mem_buf_->InputBuffer()[0]; + mem_buf_->ConsumeInput(1); } else if (first == RDB_32BITLEN) { /* Read a 32 bit len. */ - res = absl::big_endian::Load32(mem_buf_.InputBuffer().data()); - mem_buf_.ConsumeInput(4); + res = absl::big_endian::Load32(mem_buf_->InputBuffer().data()); + mem_buf_->ConsumeInput(4); } else if (first == RDB_64BITLEN) { /* Read a 64 bit len. 
*/ - res = absl::big_endian::Load64(mem_buf_.InputBuffer().data()); - mem_buf_.ConsumeInput(8); + res = absl::big_endian::Load64(mem_buf_->InputBuffer().data()); + mem_buf_->ConsumeInput(8); } else { LOG(ERROR) << "Bad length encoding " << type << " in rdbLoadLen()"; return Unexpected(errc::rdb_file_corrupted); @@ -1707,6 +1773,30 @@ auto RdbLoaderBase::LoadLen(bool* is_encoded) -> io::Result { return res; } +error_code RdbLoaderBase::HandleCompressedBlob() { + if (!zstd_decompress_) { + zstd_decompress_.reset(new ZstdDecompressImpl()); + } + + // Fetch uncompress blob + string res; + SET_OR_RETURN(FetchGenericString(), res); + + // Decompress blob and switch membuf pointer + // Last type in the compressed blob is RDB_OPCODE_COMPRESSED_BLOB_END + // in which we will switch back to the origin membuf (HandleCompressedBlobFinish) + string_view uncompressed_blob; + SET_OR_RETURN(zstd_decompress_->Decompress(res), mem_buf_); + + return kOk; +} + +error_code RdbLoaderBase::HandleCompressedBlobFinish() { + // TODO validate that all uncompressed data was fetched + mem_buf_ = &origin_mem_buf_; + return kOk; +} + error_code RdbLoader::HandleAux() { /* AUX: generic string-string fields. Use to add state to RDB * which is backward compatible. Implementations of RDB loading @@ -1777,7 +1867,7 @@ error_code RdbLoader::VerifyChecksum() { SET_OR_RETURN(FetchInt(), expected); - io::Bytes cur_buf = mem_buf_.InputBuffer(); + io::Bytes cur_buf = mem_buf_->InputBuffer(); VLOG(1) << "VerifyChecksum: input buffer len " << cur_buf.size() << ", expected " << expected; diff --git a/src/server/rdb_load.h b/src/server/rdb_load.h index ff998f06eb6f..77136b60844f 100644 --- a/src/server/rdb_load.h +++ b/src/server/rdb_load.h @@ -21,9 +21,12 @@ class EngineShardSet; class ScriptMgr; class CompactObj; +class ZstdDecompressImpl; + class RdbLoaderBase { protected: RdbLoaderBase(); + ~RdbLoaderBase(); struct LoadTrace; using MutableBytes = ::io::MutableBytes; @@ -124,11 +127,13 @@ class RdbLoaderBase { ::io::Result ReadZSetZL(); ::io::Result ReadListQuicklist(int rdbtype); ::io::Result ReadStreams(); + std::error_code HandleCompressedBlob(); + std::error_code HandleCompressedBlobFinish(); static size_t StrLen(const RdbVariant& tset); std::error_code EnsureRead(size_t min_sz) { - if (mem_buf_.InputLen() >= min_sz) + if (mem_buf_->InputLen() >= min_sz) return std::error_code{}; return EnsureReadInternal(min_sz); @@ -137,11 +142,14 @@ class RdbLoaderBase { std::error_code EnsureReadInternal(size_t min_sz); protected: - base::IoBuf mem_buf_; + base::IoBuf* mem_buf_ = nullptr; + base::IoBuf origin_mem_buf_; ::io::Source* src_ = nullptr; + size_t bytes_read_ = 0; size_t source_limit_ = SIZE_MAX; base::PODArray compr_buf_; + std::unique_ptr zstd_decompress_; }; class RdbLoader : protected RdbLoaderBase { @@ -156,7 +164,7 @@ class RdbLoader : protected RdbLoaderBase { } ::io::Bytes Leftover() const { - return mem_buf_.InputBuffer(); + return mem_buf_->InputBuffer(); } size_t bytes_read() const { diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index b505fae4ec47..ba0a1a9c225f 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/string_set.h" @@ -21,6 +22,7 @@ extern "C" { #include "redis/zset.h" } +#include "base/flags.h" #include "base/logging.h" #include "server/engine_shard_set.h" #include "server/error.h" @@ -28,6 +30,12 @@ extern "C" { #include "server/snapshot.h" #include "util/fibers/simple_channel.h" +ABSL_FLAG(int, compression_mode, 
2, + "set 0 for no compression," + "set 1 for single entry lzf compression," + "set 2 for multi entry zstd compression on df snapshot and single entry on rdb snapshot"); +ABSL_FLAG(int, zstd_compression_level, 2, "Compression level to use on zstd compression"); + namespace dfly { using namespace std; @@ -158,7 +166,8 @@ uint8_t RdbObjectType(unsigned type, unsigned encoding) { return 0; /* avoid warning */ } -RdbSerializer::RdbSerializer(io::Sink* s) : sink_(s), mem_buf_{4_KB}, tmp_buf_(nullptr) { +RdbSerializer::RdbSerializer(io::Sink* s, bool do_compression) + : sink_(s), mem_buf_{4_KB}, tmp_buf_(nullptr), do_entry_level_compression_(do_compression) { } RdbSerializer::~RdbSerializer() { @@ -639,7 +648,7 @@ error_code RdbSerializer::SaveString(string_view val) { /* Try LZF compression - under 20 bytes it's unable to compress even * aaaaaaaaaaaaaaaaaa so skip it */ size_t len = val.size(); - if (server.rdb_compression && len > 20) { + if (do_entry_level_compression_ && len > 20) { size_t comprlen, outlen = len; tmp_buf_.resize(outlen + 1); @@ -737,7 +746,7 @@ class RdbSaver::Impl { public: // We pass K=sz to say how many producers are pushing data in order to maintain // correct closing semantics - channel is closing when K producers marked it as closed. - Impl(bool align_writes, unsigned producers_len, io::Sink* sink); + Impl(bool align_writes, unsigned producers_len, CompressionMode compression_mode, io::Sink* sink); void StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard); @@ -775,13 +784,20 @@ class RdbSaver::Impl { RdbSerializer meta_serializer_; SliceSnapshot::RecordChannel channel_; std::optional aligned_buf_; + CompressionMode + compression_mode_; // Single entry compression is compatible with redis rdb snapshot + // Multi entry compression is available only on df snapshot, this will + // make snapshot size smaller and opreation faster. }; // We pass K=sz to say how many producers are pushing data in order to maintain // correct closing semantics - channel is closing when K producers marked it as closed. 
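+// compression_mode selects NONE, SINGLE_ENTRY (per-value LZF, redis-compatible) or
+// MULTY_ENTRY (blob-level zstd). The meta serializer below only ever uses the
+// per-entry path; zstd framing is applied later, in SliceSnapshot::PushFileToChannel.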
-RdbSaver::Impl::Impl(bool align_writes, unsigned producers_len, io::Sink* sink) +RdbSaver::Impl::Impl(bool align_writes, unsigned producers_len, CompressionMode compression_mode, + io::Sink* sink) : sink_(sink), shard_snapshots_(producers_len), - meta_serializer_(sink), channel_{128, producers_len} { + meta_serializer_(sink, compression_mode != CompressionMode::NONE), channel_{128, + producers_len}, + compression_mode_(compression_mode) { if (align_writes) { aligned_buf_.emplace(kBufLen, sink); meta_serializer_.set_sink(&aligned_buf_.value()); @@ -863,7 +879,7 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) { void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard) { auto& s = GetSnapshot(shard); - s.reset(new SliceSnapshot(&shard->db_slice(), &channel_)); + s.reset(new SliceSnapshot(&shard->db_slice(), &channel_, compression_mode_)); s->Start(stream_journal, cll); } @@ -906,21 +922,38 @@ unique_ptr& RdbSaver::Impl::GetSnapshot(EngineShard* shard) { RdbSaver::RdbSaver(::io::Sink* sink, SaveMode save_mode, bool align_writes) { CHECK_NOTNULL(sink); - + int compression_mode = absl::GetFlag(FLAGS_compression_mode); int producer_count = 0; switch (save_mode) { case SaveMode::SUMMARY: producer_count = 0; + if (compression_mode == 1 || compression_mode == 2) { + compression_mode_ = CompressionMode::SINGLE_ENTRY; + } else { + compression_mode_ = CompressionMode::NONE; + } break; case SaveMode::SINGLE_SHARD: producer_count = 1; + if (compression_mode == 2) { + compression_mode_ = CompressionMode::MULTY_ENTRY; + } else if (compression_mode == 1) { + compression_mode_ = CompressionMode::SINGLE_ENTRY; + } else { + compression_mode_ = CompressionMode::NONE; + } break; case SaveMode::RDB: producer_count = shard_set->size(); + if (compression_mode == 1 || compression_mode == 2) { + compression_mode_ = CompressionMode::SINGLE_ENTRY; + } else { + compression_mode_ = CompressionMode::NONE; + } break; } - impl_.reset(new Impl(align_writes, producer_count, sink)); + impl_.reset(new Impl(align_writes, producer_count, compression_mode_, sink)); save_mode_ = save_mode; } @@ -1028,4 +1061,75 @@ void RdbSaver::Cancel() { impl_->Cancel(); } +class ZstdCompressSerializer::ZstdCompressImpl { + public: + ZstdCompressImpl() { + cctx_ = ZSTD_createCCtx(); + compression_level_ = absl::GetFlag(FLAGS_zstd_compression_level); + } + ~ZstdCompressImpl() { + ZSTD_freeCCtx(cctx_); + + VLOG(1) << "zstd compressed size: " << compressed_size_total_; + VLOG(1) << "zstd uncompressed size: " << uncompressed_size_total_; + } + + std::string_view Compress(std::string_view str); + + private: + ZSTD_CCtx* cctx_; + int compression_level_ = 1; + base::PODArray compr_buf_; + uint32_t compressed_size_total_ = 0; + uint32_t uncompressed_size_total_ = 0; +}; + +std::string_view ZstdCompressSerializer::ZstdCompressImpl::Compress(string_view str) { + size_t buf_size = ZSTD_compressBound(str.size()); + if (compr_buf_.size() < buf_size) { + compr_buf_.reserve(buf_size); + } + size_t compressed_size = ZSTD_compressCCtx(cctx_, compr_buf_.data(), compr_buf_.capacity(), + str.data(), str.size(), compression_level_); + + compressed_size_total_ += compressed_size; + uncompressed_size_total_ += str.size(); + return string_view(reinterpret_cast(compr_buf_.data()), compressed_size); +} + +ZstdCompressSerializer::ZstdCompressSerializer() { + impl_.reset(new ZstdCompressImpl()); +} + +std::pair ZstdCompressSerializer::Compress(std::string_view str) { + if (str.size() < 
kMinStrSizeToCompress) { + ++small_str_count_; + return std::make_pair(false, ""); + } + + // Compress the string + string_view compressed_res = impl_->Compress(str); + if (compressed_res.size() > str.size() * kMinCompressionReductionPrecentage) { + ++compression_no_effective_; + return std::make_pair(false, ""); + } + + string serialized_compressed_blob; + // First write opcode for compressed string + serialized_compressed_blob.push_back(RDB_OPCODE_COMPRESSED_BLOB_START); + // Get compressed string len encoded + uint8_t buf[9]; + unsigned enclen = SerializeLen(compressed_res.size(), buf); + + // Write encoded compressed string len and than the compressed string + serialized_compressed_blob.append(reinterpret_cast(buf), enclen); + serialized_compressed_blob.append(compressed_res); + return std::make_pair(true, std::move(serialized_compressed_blob)); +} + +ZstdCompressSerializer::~ZstdCompressSerializer() { + VLOG(1) << "zstd compression not effective: " << compression_no_effective_; + VLOG(1) << "small string none compression applied: " << small_str_count_; +} + } // namespace dfly diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h index b6b04f453dd0..f54420660f10 100644 --- a/src/server/rdb_save.h +++ b/src/server/rdb_save.h @@ -59,6 +59,12 @@ enum class SaveMode { RDB, // Save .rdb file. Expected to read all shards. }; +enum class CompressionMode { + NONE, + SINGLE_ENTRY, + MULTY_ENTRY, +}; + class RdbSaver { public: // single_shard - true means that we run RdbSaver on a single shard and we do not use @@ -101,6 +107,29 @@ class RdbSaver { SaveMode save_mode_; std::unique_ptr impl_; + CompressionMode compression_mode_; +}; + +class ZstdCompressSerializer { + public: + ZstdCompressSerializer(); + ZstdCompressSerializer(const ZstdCompressSerializer&) = delete; + void operator=(const ZstdCompressSerializer&) = delete; + + ~ZstdCompressSerializer(); + + // Returns a pair consisting of an bool denoting whether the string was compressed + // and a string the result of compression. If given string was not compressed returned + // string will be empty. + std::pair Compress(std::string_view str); + + private: + class ZstdCompressImpl; + std::unique_ptr impl_; + static constexpr size_t kMinStrSizeToCompress = 256; + static constexpr double kMinCompressionReductionPrecentage = 0.95; + uint32_t compression_no_effective_ = 0; + uint32_t small_str_count_ = 0; }; class RdbSerializer { @@ -108,7 +137,7 @@ class RdbSerializer { // TODO: for aligned cased, it does not make sense that RdbSerializer buffers into unaligned // mem_buf_ and then flush it into the next level. We should probably use AlignedBuffer // directly. 
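+  // When do_entry_level_compression is true, SaveString() may LZF-compress
+  // individual values (the encoding redis-compatible .rdb readers understand);
+  // multi-entry zstd compression is handled one level up, by ZstdCompressSerializer.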
- RdbSerializer(::io::Sink* s); + RdbSerializer(::io::Sink* s, bool do_entry_level_compression); ~RdbSerializer(); @@ -166,6 +195,7 @@ class RdbSerializer { base::IoBuf mem_buf_; base::PODArray tmp_buf_; std::string tmp_str_; + bool do_entry_level_compression_; }; } // namespace dfly diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index 8722c05ac984..f6f58721c1b2 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -15,6 +15,7 @@ extern "C" { #include "server/db_slice.h" #include "server/engine_shard_set.h" #include "server/journal/journal.h" +#include "server/rdb_extensions.h" #include "server/rdb_save.h" #include "util/fiber_sched_algo.h" #include "util/proactor_base.h" @@ -27,7 +28,8 @@ using namespace chrono_literals; namespace this_fiber = ::boost::this_fiber; using boost::fibers::fiber; -SliceSnapshot::SliceSnapshot(DbSlice* slice, RecordChannel* dest) : db_slice_(slice), dest_(dest) { +SliceSnapshot::SliceSnapshot(DbSlice* slice, RecordChannel* dest, CompressionMode compression_mode) + : db_slice_(slice), dest_(dest), compression_mode_(compression_mode) { db_array_ = slice->databases(); } @@ -52,7 +54,9 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) { } sfile_.reset(new io::StringFile); - rdb_serializer_.reset(new RdbSerializer(sfile_.get())); + + bool do_compression = (compression_mode_ == CompressionMode::SINGLE_ENTRY); + rdb_serializer_.reset(new RdbSerializer(sfile_.get(), do_compression)); snapshot_fb_ = fiber([this, stream_journal, cll] { SerializeEntriesFb(cll); @@ -197,9 +201,10 @@ bool SliceSnapshot::FlushSfile(bool force) { } VLOG(2) << "FlushSfile " << sfile_->val.size() << " bytes"; - DbRecord rec = GetDbRecord(savecb_current_db_, std::move(sfile_->val), num_records_in_blob_); + uint32_t record_num = num_records_in_blob_; num_records_in_blob_ = 0; // We can not move this line after the push, because Push is blocking. - dest_->Push(std::move(rec)); + bool multi_entries_compression = (compression_mode_ == CompressionMode::MULTY_ENTRY); + PushFileToChannel(sfile_.get(), savecb_current_db_, record_num, multi_entries_compression); return true; } @@ -266,17 +271,15 @@ void SliceSnapshot::OnJournalEntry(const journal::Entry& entry) { CHECK(res); // we write to StringFile. } else { io::StringFile sfile; - RdbSerializer tmp_serializer(&sfile); + bool serializer_compression = (compression_mode_ != CompressionMode::NONE); + RdbSerializer tmp_serializer(&sfile, serializer_compression); io::Result res = tmp_serializer.SaveEntry(pkey, *entry.pval_ptr, entry.expire_ms); CHECK(res); // we write to StringFile. 
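+      // Single journal records are flushed and pushed below without the
+      // multi-entry zstd framing (should_compress=false in PushFileToChannel).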
error_code ec = tmp_serializer.FlushMem(); CHECK(!ec && !sfile.val.empty()); - - DbRecord rec = GetDbRecord(entry.db_ind, std::move(sfile.val), 1); - - dest_->Push(std::move(rec)); + PushFileToChannel(&sfile, entry.db_ind, 1, false); } } @@ -298,7 +301,8 @@ unsigned SliceSnapshot::SerializePhysicalBucket(DbIndex db_index, PrimeTable::bu num_records_in_blob_ += result; } else { io::StringFile sfile; - RdbSerializer tmp_serializer(&sfile); + bool serializer_compression = (compression_mode_ != CompressionMode::NONE); + RdbSerializer tmp_serializer(&sfile, serializer_compression); while (!it.is_done()) { ++result; @@ -307,12 +311,27 @@ unsigned SliceSnapshot::SerializePhysicalBucket(DbIndex db_index, PrimeTable::bu } error_code ec = tmp_serializer.FlushMem(); CHECK(!ec && !sfile.val.empty()); - - dest_->Push(GetDbRecord(db_index, std::move(sfile.val), result)); + PushFileToChannel(&sfile, db_index, result, false); } return result; } +void SliceSnapshot::PushFileToChannel(io::StringFile* sfile, DbIndex db_index, unsigned num_records, + bool should_compress) { + string string_to_push = std::move(sfile->val); + + if (should_compress) { + if (!zstd_serializer_) { + zstd_serializer_.reset(new ZstdCompressSerializer()); + } + auto comp_res = zstd_serializer_->Compress(string_to_push); + if (comp_res.first) { + string_to_push.swap(comp_res.second); + } + } + dest_->Push(GetDbRecord(db_index, std::move(string_to_push), num_records)); +} + auto SliceSnapshot::GetDbRecord(DbIndex db_index, std::string value, unsigned num_records) -> DbRecord { channel_bytes_ += value.size(); diff --git a/src/server/snapshot.h b/src/server/snapshot.h index a2d2c15d03b7..5fd336646bc3 100644 --- a/src/server/snapshot.h +++ b/src/server/snapshot.h @@ -7,8 +7,10 @@ #include #include +#include "base/pod_array.h" #include "io/file.h" #include "server/db_slice.h" +#include "server/rdb_save.h" #include "server/table.h" #include "util/fibers/simple_channel.h" @@ -19,6 +21,7 @@ struct Entry; } // namespace journal class RdbSerializer; +class ZstdCompressSerializer; class SliceSnapshot { public: @@ -34,7 +37,7 @@ class SliceSnapshot { using RecordChannel = ::util::fibers_ext::SimpleChannel>; - SliceSnapshot(DbSlice* slice, RecordChannel* dest); + SliceSnapshot(DbSlice* slice, RecordChannel* dest, CompressionMode compression_mode); ~SliceSnapshot(); void Start(bool stream_journal, const Cancellation* cll); @@ -63,7 +66,6 @@ class SliceSnapshot { private: void CloseRecordChannel(); - void SerializeEntriesFb(const Cancellation* cll); void SerializeSingleEntry(DbIndex db_index, const PrimeKey& pk, const PrimeValue& pv, @@ -79,6 +81,8 @@ class SliceSnapshot { // Updates the version of the bucket to snapshot version. unsigned SerializePhysicalBucket(DbIndex db_index, PrimeTable::bucket_iterator it); DbRecord GetDbRecord(DbIndex db_index, std::string value, unsigned num_records); + void PushFileToChannel(io::StringFile* sfile, DbIndex db_index, unsigned num_records, + bool should_compress); DbSlice* db_slice_; DbTableArray db_array_; @@ -98,6 +102,9 @@ class SliceSnapshot { uint64_t rec_id_ = 0; uint32_t num_records_in_blob_ = 0; + CompressionMode compression_mode_; + std::unique_ptr zstd_serializer_; + uint32_t journal_cb_id_ = 0; ::boost::fibers::fiber snapshot_fb_;