Skip to content

Commit

Permalink
Implement one-shot fallback for magicless format (#3971)
Browse files Browse the repository at this point in the history
  • Loading branch information
embg authored Mar 18, 2024
1 parent a595e58 commit 7d970bd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
22 changes: 13 additions & 9 deletions lib/decompress/zstd_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -729,17 +729,17 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret)
return frameSizeInfo;
}

static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize)
static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format)
{
ZSTD_frameSizeInfo frameSizeInfo;
ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo));

#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
if (ZSTD_isLegacy(src, srcSize))
if (format == ZSTD_f_zstd1 && ZSTD_isLegacy(src, srcSize))
return ZSTD_findFrameSizeInfoLegacy(src, srcSize);
#endif

if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
&& (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize);
assert(ZSTD_isError(frameSizeInfo.compressedSize) ||
Expand All @@ -753,7 +753,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
ZSTD_frameHeader zfh;

/* Extract Frame Header */
{ size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize);
{ size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format);
if (ZSTD_isError(ret))
return ZSTD_errorFrameSizeInfo(ret);
if (ret > 0)
Expand Down Expand Up @@ -796,13 +796,17 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
}
}

static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) {
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format);
return frameSizeInfo.compressedSize;
}

/** ZSTD_findFrameCompressedSize() :
* See docs in zstd.h
* Note: compatible with legacy mode */
size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
{
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
return frameSizeInfo.compressedSize;
return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1);
}

/** ZSTD_decompressBound() :
Expand All @@ -816,7 +820,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
unsigned long long bound = 0;
/* Iterate over each frame */
while (srcSize > 0) {
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
size_t const compressedSize = frameSizeInfo.compressedSize;
unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
Expand All @@ -836,7 +840,7 @@ size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)

/* Iterate over each frame */
while (srcSize > 0) {
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
size_t const compressedSize = frameSizeInfo.compressedSize;
unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
ZSTD_frameHeader zfh;
Expand Down Expand Up @@ -2178,7 +2182,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
&& zds->fParams.frameType != ZSTD_skippableFrame
&& (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) {
size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart));
size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format);
if (cSize <= (size_t)(iend-istart)) {
/* shortcut : using single-pass mode */
size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds));
Expand Down
35 changes: 35 additions & 0 deletions tests/zstreamtest.c
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,41 @@ static int basicUnitTests(U32 seed, double compressibility, int bigTests)
}
DISPLAYLEVEL(3, "OK \n");

DISPLAYLEVEL(3, "test%3i : Test single-shot fallback for magicless mode: ", testNb++);
{
// Aquire resources
size_t const srcSize = COMPRESSIBLE_NOISE_LENGTH;
void* src = malloc(srcSize);
size_t const dstSize = ZSTD_compressBound(srcSize);
void* dst = malloc(dstSize);
size_t const valSize = srcSize;
void* val = malloc(valSize);
ZSTD_inBuffer inBuf = { dst, dstSize, 0 };
ZSTD_outBuffer outBuf = { val, valSize, 0 };
ZSTD_CCtx* cctx = ZSTD_createCCtx();
ZSTD_DCtx* dctx = ZSTD_createDCtx();
CHECK(!src || !dst || !val || !dctx || !cctx, "memory allocation failure");

// Write test data for decompression to dst
RDG_genBuffer(src, srcSize, compressibility, 0.0, 0xdeadbeef);
CHECK_Z(ZSTD_CCtx_setParameter(cctx, ZSTD_c_format, ZSTD_f_zstd1_magicless));
CHECK_Z(ZSTD_compress2(cctx, dst, dstSize, src, srcSize));

// Run decompression
CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_format, ZSTD_f_zstd1_magicless));
CHECK_Z(ZSTD_decompressStream(dctx, &outBuf, &inBuf));

// Validate
CHECK(outBuf.pos != srcSize, "decompressed size must match");
CHECK(memcmp(src, val, srcSize) != 0, "decompressed data must match");

// Cleanup
free(src); free(dst); free(val);
ZSTD_freeCCtx(cctx);
ZSTD_freeDCtx(dctx);
}
DISPLAYLEVEL(3, "OK \n");

_end:
FUZ_freeDictionary(dictionary);
ZSTD_freeCStream(zc);
Expand Down

0 comments on commit 7d970bd

Please sign in to comment.