From 93468ad60cf65ffd8e7dad65f23bfa4b270277f6 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 15 Mar 2024 10:50:27 -0700 Subject: [PATCH] Implement one-shot fallback for magicless format --- lib/decompress/zstd_decompress.c | 22 ++++++++++++-------- tests/zstreamtest.c | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index f657974385..ee2cda3b63 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -729,17 +729,17 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret) return frameSizeInfo; } -static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize) +static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format) { ZSTD_frameSizeInfo frameSizeInfo; ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo)); #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1) - if (ZSTD_isLegacy(src, srcSize)) + if (format == ZSTD_f_zstd1 && ZSTD_isLegacy(src, srcSize)) return ZSTD_findFrameSizeInfoLegacy(src, srcSize); #endif - if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE) + if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE) && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize); assert(ZSTD_isError(frameSizeInfo.compressedSize) || @@ -753,7 +753,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize ZSTD_frameHeader zfh; /* Extract Frame Header */ - { size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize); + { size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format); if (ZSTD_isError(ret)) return ZSTD_errorFrameSizeInfo(ret); if (ret > 0) @@ -796,13 +796,17 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize } } +static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format); + return frameSizeInfo.compressedSize; +} + /** ZSTD_findFrameCompressedSize() : * See docs in zstd.h * Note: compatible with legacy mode */ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); - return frameSizeInfo.compressedSize; + return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1); } /** ZSTD_decompressBound() : @@ -816,7 +820,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) unsigned long long bound = 0; /* Iterate over each frame */ while (srcSize > 0) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); size_t const compressedSize = frameSizeInfo.compressedSize; unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) @@ -836,7 +840,7 @@ size_t ZSTD_decompressionMargin(void const* src, size_t srcSize) /* Iterate over each frame */ while (srcSize > 0) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); size_t const compressedSize = frameSizeInfo.compressedSize; unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; ZSTD_frameHeader zfh; @@ -2178,7 +2182,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN && zds->fParams.frameType != ZSTD_skippableFrame && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) { - size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart)); + size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format); if (cSize <= (size_t)(iend-istart)) { /* shortcut : using single-pass mode */ size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds)); diff --git a/tests/zstreamtest.c b/tests/zstreamtest.c index 7cc4068bc0..e0ee4c3e93 100644 --- a/tests/zstreamtest.c +++ b/tests/zstreamtest.c @@ -2417,6 +2417,41 @@ static int basicUnitTests(U32 seed, double compressibility, int bigTests) } DISPLAYLEVEL(3, "OK \n"); + DISPLAYLEVEL(3, "test%3i : Test single-shot fallback for magicless mode: ", testNb++); + { + // Aquire resources + size_t const srcSize = COMPRESSIBLE_NOISE_LENGTH; + void* src = malloc(srcSize); + size_t const dstSize = ZSTD_compressBound(srcSize); + void* dst = malloc(dstSize); + size_t const valSize = srcSize; + void* val = malloc(valSize); + ZSTD_inBuffer inBuf = { dst, dstSize, 0 }; + ZSTD_outBuffer outBuf = { val, valSize, 0 }; + ZSTD_CCtx* cctx = ZSTD_createCCtx(); + ZSTD_DCtx* dctx = ZSTD_createDCtx(); + CHECK(!src || !dst || !val || !dctx || !cctx, "memory allocation failure"); + + // Write test data for decompression to dst + RDG_genBuffer(src, srcSize, compressibility, 0.0, 0xdeadbeef); + CHECK_Z(ZSTD_CCtx_setParameter(cctx, ZSTD_c_format, ZSTD_f_zstd1_magicless)); + CHECK_Z(ZSTD_compress2(cctx, dst, dstSize, src, srcSize)); + + // Run decompression + CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_format, ZSTD_f_zstd1_magicless)); + CHECK_Z(ZSTD_decompressStream(dctx, &outBuf, &inBuf)); + + // Validate + CHECK(outBuf.pos != srcSize, "decompressed size must match"); + CHECK(memcmp(src, val, srcSize) != 0, "decompressed data must match"); + + // Cleanup + free(src); free(dst); free(val); + ZSTD_freeCCtx(cctx); + ZSTD_freeDCtx(dctx); + } + DISPLAYLEVEL(3, "OK \n"); + _end: FUZ_freeDictionary(dictionary); ZSTD_freeCStream(zc);