Skip to content

Commit

Permalink
Streaming decompression can detect incorrect header ID sooner
Browse files Browse the repository at this point in the history
Streaming decompression used to wait for a minimum of 5 bytes before attempting to decode anything.
As a result, when fewer than 5 bytes were provided
and those bytes were incorrect,
no error was reported:
the streaming API would simply request more data until at least 5 bytes were available.

This PR makes it possible to detect incorrect Frame IDs as soon as the first byte is provided.

Fix #3169
  • Loading branch information
Cyan4973 committed Jun 22, 2022
1 parent f6ef143 commit 53aadc7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
52 changes: 39 additions & 13 deletions lib/decompress/zstd_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@
*************************************/

#define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4
#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
* Currently, that means a 0.75 load factor.
* So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
* the load factor of the ddict hash set.
*/
#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
* Currently, that means a 0.75 load factor.
* So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
* the load factor of the ddict hash set.
*/

#define DDICT_HASHSET_TABLE_BASE_SIZE 64
#define DDICT_HASHSET_RESIZE_FACTOR 2
Expand Down Expand Up @@ -439,16 +439,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize)
* note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless
* @return : 0, `zfhPtr` is correctly filled,
* >0, `srcSize` is too small, value is wanted `srcSize` amount,
* or an error code, which can be tested using ZSTD_isError() */
** or an error code, which can be tested using ZSTD_isError() */
size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format)
{
const BYTE* ip = (const BYTE*)src;
size_t const minInputSize = ZSTD_startingInputLength(format);

ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */
if (srcSize < minInputSize) return minInputSize;
RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter");
DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize);

if (srcSize > 0) {
/* note : technically could be considered an assert(), since it's an invalid entry */
RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0");
}
if (srcSize < minInputSize) {
if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) {
/* when receiving less than @minInputSize bytes,
* check that these bytes at least match the start of a supported magic number
* in order to error out early if they don't.
**/
size_t const toCopy = MIN(4, srcSize);
unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER);
assert(src != NULL);
ZSTD_memcpy(hbuf, src, toCopy);
if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) {
/* not a zstd frame : let's check if it's a skippable frame */
MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START);
ZSTD_memcpy(hbuf, src, toCopy);
if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) {
RETURN_ERROR(prefix_unknown,
"first bytes don't correspond to any supported magic number");
} } }
return minInputSize;
}

ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */
if ( (format != ZSTD_f_zstd1_magicless)
&& (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) {
if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
Expand All @@ -459,9 +483,7 @@ size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, s
zfhPtr->frameContentSize = MEM_readLE32((const char *)src + ZSTD_FRAMEIDSIZE);
zfhPtr->frameType = ZSTD_skippableFrame;
return 0;
}
RETURN_ERROR(prefix_unknown, "");
}
} }

/* ensure there is enough `srcSize` to fully read/decode frame header */
{ size_t const fhsize = ZSTD_frameHeaderSize_internal(src, srcSize, format);
Expand Down Expand Up @@ -1981,7 +2003,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
if (zds->refMultipleDDicts && zds->ddictSet) {
ZSTD_DCtx_selectFrameDDict(zds);
}
DEBUGLOG(5, "header size : %u", (U32)hSize);
if (ZSTD_isError(hSize)) {
#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT>=1)
U32 const legacyVersion = ZSTD_isLegacy(istart, iend-istart);
Expand Down Expand Up @@ -2013,6 +2034,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
zds->lhSize += remainingInput;
}
input->pos = input->size;
/* check first few bytes */
FORWARD_IF_ERROR(
ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format),
"First few bytes detected incorrect" );
/* return hint input size */
return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */
}
assert(ip != NULL);
Expand Down
9 changes: 9 additions & 0 deletions tests/zstreamtest.c
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ static int basicUnitTests(U32 seed, double compressibility)
} }
DISPLAYLEVEL(3, "OK \n");

/* check decompression fails early if first bytes are wrong */
DISPLAYLEVEL(3, "test%3i : early decompression error if first bytes are incorrect : ", testNb++);
{ const char buf[3] = { 0 }; /* too short, not enough to start decoding header */
ZSTD_inBuffer inb = { buf, sizeof(buf), 0 };
size_t const remaining = ZSTD_decompressStream(zd, &outBuff, &inb);
if (!ZSTD_isError(remaining)) goto _output_error; /* should have errored out immediately (note: this does not test the exact error code) */
}
DISPLAYLEVEL(3, "OK \n");

/* context size functions */
DISPLAYLEVEL(3, "test%3i : estimate DStream size : ", testNb++);
{ ZSTD_frameHeader fhi;
Expand Down

0 comments on commit 53aadc7

Please sign in to comment.