Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[huf] Improve fast C & ASM performance on small data #3827

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 50 additions & 39 deletions lib/decompress/huf_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,18 @@ static size_t HUF_initFastDStream(BYTE const* ip) {
* op [in/out] - The output pointers, must be updated to reflect what is written.
* bits [in/out] - The bitstream containers, must be updated to reflect the current state.
* dt [in] - The decoding table.
* ilimit [in] - The input limit, stop when any input pointer is below ilimit.
* ilowest [in] - The beginning of the valid range of the input. Decoders may read
* down to this pointer. It may be below iend[0].
* oend [in] - The end of the output stream. op[3] must not cross oend.
* iend [in] - The end of each input stream. ip[i] may cross iend[i],
* as long as it is above ilimit, but that indicates corruption.
* as long as it is above ilowest, but that indicates corruption.
*/
typedef struct {
BYTE const* ip[4];
BYTE* op[4];
U64 bits[4];
void const* dt;
BYTE const* ilimit;
BYTE const* ilowest;
BYTE* oend;
BYTE const* iend[4];
} HUF_DecompressFastArgs;
Expand All @@ -192,7 +193,7 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds
void const* dt = DTable + 1;
U32 const dtLog = HUF_getDTableDesc(DTable).tableLog;

const BYTE* const ilimit = (const BYTE*)src + 6 + 8;
const BYTE* const istart = (const BYTE*)src;

BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);

Expand All @@ -202,6 +203,11 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds
if (!MEM_isLittleEndian() || MEM_32bits())
return 0;

/* Avoid nullptr addition */
if (dstSize == 0)
return 0;
assert(dst != NULL);

/* strict minimum : jump table + 1 byte per stream */
if (srcSize < 10)
return ERROR(corruption_detected);
Expand All @@ -215,7 +221,6 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds

/* Read the jump table. */
{
const BYTE* const istart = (const BYTE*)src;
size_t const length1 = MEM_readLE16(istart);
size_t const length2 = MEM_readLE16(istart+2);
size_t const length3 = MEM_readLE16(istart+4);
Expand All @@ -227,10 +232,8 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds

/* HUF_initFastDStream() requires this, and this small of an input
* won't benefit from the ASM loop anyways.
* length1 must be >= 16 so that ip[0] >= ilimit before the loop
* starts.
*/
if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8)
if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8)
return 0;
if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */
}
Expand Down Expand Up @@ -262,11 +265,12 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds
args->bits[2] = HUF_initFastDStream(args->ip[2]);
args->bits[3] = HUF_initFastDStream(args->ip[3]);

/* If ip[] >= ilimit, it is guaranteed to be safe to
* reload bits[]. It may be beyond its section, but is
* guaranteed to be valid (>= istart).
*/
args->ilimit = ilimit;
/* The decoders must be sure to never read beyond ilowest.
* This is lower than iend[0], but allowing decoders to read
* down to ilowest can allow an extra iteration or two in the
* fast loop.
*/
args->ilowest = istart;

args->oend = oend;
args->dt = dt;
Expand All @@ -291,7 +295,7 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArg
assert(sizeof(size_t) == 8);
bit->bitContainer = MEM_readLEST(args->ip[stream]);
bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]);
bit->start = (const char*)args->iend[0];
bit->start = (const char*)args->ilowest;
bit->limitPtr = bit->start + sizeof(size_t);
bit->ptr = (const char*)args->ip[stream];

Expand Down Expand Up @@ -717,7 +721,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
BYTE* op[4];
U16 const* const dtable = (U16 const*)args->dt;
BYTE* const oend = args->oend;
BYTE const* const ilimit = args->ilimit;
BYTE const* const ilowest = args->ilowest;

/* Copy the arguments to local variables */
ZSTD_memcpy(&bits, &args->bits, sizeof(bits));
Expand All @@ -735,7 +739,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
#ifndef NDEBUG
for (stream = 0; stream < 4; ++stream) {
assert(op[stream] <= (stream == 3 ? oend : op[stream + 1]));
assert(ip[stream] >= ilimit);
assert(ip[stream] >= ilowest);
}
#endif
/* Compute olimit */
Expand All @@ -745,7 +749,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
/* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes
* per stream.
*/
size_t const iiters = (size_t)(ip[0] - ilimit) / 7;
size_t const iiters = (size_t)(ip[0] - ilowest) / 7;
/* We can safely run iters iterations before running bounds checks */
size_t const iters = MIN(oiters, iiters);
size_t const symbols = iters * 5;
Expand All @@ -756,8 +760,8 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
*/
olimit = op[3] + symbols;

/* Exit fast decoding loop once we get close to the end. */
if (op[3] + 20 > olimit)
/* Exit fast decoding loop once we reach the end. */
if (op[3] == olimit)
break;

/* Exit the decoding loop if any input pointer has crossed the
Expand Down Expand Up @@ -836,7 +840,7 @@ HUF_decompress4X1_usingDTable_internal_fast(
HUF_DecompressFastLoopFn loopFn)
{
void const* dt = DTable + 1;
const BYTE* const iend = (const BYTE*)cSrc + 6;
BYTE const* const ilowest = (BYTE const*)cSrc;
BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);
HUF_DecompressFastArgs args;
{ size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable);
Expand All @@ -845,18 +849,22 @@ HUF_decompress4X1_usingDTable_internal_fast(
return 0;
}

assert(args.ip[0] >= args.ilimit);
assert(args.ip[0] >= args.ilowest);
loopFn(&args);

/* Our loop guarantees that ip[] >= ilimit and that we haven't
/* Our loop guarantees that ip[] >= ilowest and that we haven't
* overwritten any op[].
*/
assert(args.ip[0] >= iend);
assert(args.ip[1] >= iend);
assert(args.ip[2] >= iend);
assert(args.ip[3] >= iend);
assert(args.ip[0] >= ilowest);
assert(args.ip[0] >= ilowest);
assert(args.ip[1] >= ilowest);
assert(args.ip[2] >= ilowest);
assert(args.ip[3] >= ilowest);
assert(args.op[3] <= oend);
(void)iend;

assert(ilowest == args.ilowest);
assert(ilowest + 6 == args.iend[0]);
(void)ilowest;

/* finish bit streams one by one. */
{ size_t const segmentSize = (dstSize+3) / 4;
Expand Down Expand Up @@ -1512,7 +1520,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
BYTE* op[4];
BYTE* oend[4];
HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt;
BYTE const* const ilimit = args->ilimit;
BYTE const* const ilowest = args->ilowest;

/* Copy the arguments to local registers. */
ZSTD_memcpy(&bits, &args->bits, sizeof(bits));
Expand All @@ -1535,7 +1543,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
#ifndef NDEBUG
for (stream = 0; stream < 4; ++stream) {
assert(op[stream] <= oend[stream]);
assert(ip[stream] >= ilimit);
assert(ip[stream] >= ilowest);
}
#endif
/* Compute olimit */
Expand All @@ -1548,7 +1556,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
* We also know that each input pointer is >= ip[0]. So we can run
* iters loops before running out of input.
*/
size_t iters = (size_t)(ip[0] - ilimit) / 7;
size_t iters = (size_t)(ip[0] - ilowest) / 7;
/* Each iteration can produce up to 10 bytes of output per stream.
* Each output stream my advance at different rates. So take the
* minimum number of safe iterations among all the output streams.
Expand All @@ -1566,8 +1574,8 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
*/
olimit = op[3] + (iters * 5);

/* Exit the fast decoding loop if we are too close to the end. */
if (op[3] + 10 > olimit)
/* Exit the fast decoding loop once we reach the end. */
if (op[3] == olimit)
break;

/* Exit the decoding loop if any input pointer has crossed the
Expand Down Expand Up @@ -1652,7 +1660,7 @@ HUF_decompress4X2_usingDTable_internal_fast(
const HUF_DTable* DTable,
HUF_DecompressFastLoopFn loopFn) {
void const* dt = DTable + 1;
const BYTE* const iend = (const BYTE*)cSrc + 6;
const BYTE* const ilowest = (const BYTE*)cSrc;
BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);
HUF_DecompressFastArgs args;
{
Expand All @@ -1662,16 +1670,19 @@ HUF_decompress4X2_usingDTable_internal_fast(
return 0;
}

assert(args.ip[0] >= args.ilimit);
assert(args.ip[0] >= args.ilowest);
loopFn(&args);

/* note : op4 already verified within main loop */
assert(args.ip[0] >= iend);
assert(args.ip[1] >= iend);
assert(args.ip[2] >= iend);
assert(args.ip[3] >= iend);
assert(args.ip[0] >= ilowest);
assert(args.ip[1] >= ilowest);
assert(args.ip[2] >= ilowest);
assert(args.ip[3] >= ilowest);
assert(args.op[3] <= oend);
(void)iend;

assert(ilowest == args.ilowest);
assert(ilowest + 6 == args.iend[0]);
(void)ilowest;

/* finish bitStreams one by one */
{
Expand Down
34 changes: 16 additions & 18 deletions lib/decompress/huf_decompress_amd64.S
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
movq 88(%rax), %bits3
movq 96(%rax), %dtable
push %rax /* argument */
push 104(%rax) /* ilimit */
push 104(%rax) /* ilowest */
push 112(%rax) /* oend */
push %olimit /* olimit space */

Expand All @@ -156,11 +156,11 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
shrq $2, %r15

movq %ip0, %rax /* rax = ip0 */
movq 40(%rsp), %rdx /* rdx = ilimit */
subq %rdx, %rax /* rax = ip0 - ilimit */
movq %rax, %rbx /* rbx = ip0 - ilimit */
movq 40(%rsp), %rdx /* rdx = ilowest */
subq %rdx, %rax /* rax = ip0 - ilowest */
movq %rax, %rbx /* rbx = ip0 - ilowest */

/* rdx = (ip0 - ilimit) / 7 */
/* rdx = (ip0 - ilowest) / 7 */
movabsq $2635249153387078803, %rdx
mulq %rdx
subq %rdx, %rbx
Expand All @@ -183,9 +183,8 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:

/* If (op3 + 20 > olimit) */
movq %op3, %rax /* rax = op3 */
addq $20, %rax /* rax = op3 + 20 */
cmpq %rax, %olimit /* op3 + 20 > olimit */
jb .L_4X1_exit
cmpq %rax, %olimit /* op3 == olimit */
je .L_4X1_exit

/* If (ip1 < ip0) go to exit */
cmpq %ip0, %ip1
Expand Down Expand Up @@ -316,7 +315,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
/* Restore stack (oend & olimit) */
pop %rax /* olimit */
pop %rax /* oend */
pop %rax /* ilimit */
pop %rax /* ilowest */
pop %rax /* arg */

/* Save ip / op / bits */
Expand Down Expand Up @@ -387,7 +386,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
movq 96(%rax), %dtable
push %rax /* argument */
push %rax /* olimit */
push 104(%rax) /* ilimit */
push 104(%rax) /* ilowest */

movq 112(%rax), %rax
push %rax /* oend3 */
Expand All @@ -414,9 +413,9 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:

/* We can consume up to 7 input bytes each iteration. */
movq %ip0, %rax /* rax = ip0 */
movq 40(%rsp), %rdx /* rdx = ilimit */
subq %rdx, %rax /* rax = ip0 - ilimit */
movq %rax, %r15 /* r15 = ip0 - ilimit */
movq 40(%rsp), %rdx /* rdx = ilowest */
subq %rdx, %rax /* rax = ip0 - ilowest */
movq %rax, %r15 /* r15 = ip0 - ilowest */

/* rdx = rax / 7 */
movabsq $2635249153387078803, %rdx
Expand All @@ -426,7 +425,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
addq %r15, %rdx
shrq $2, %rdx

/* r15 = (ip0 - ilimit) / 7 */
/* r15 = (ip0 - ilowest) / 7 */
movq %rdx, %r15

/* r15 = min(r15, min(oend0 - op0, oend1 - op1, oend2 - op2, oend3 - op3) / 10) */
Expand Down Expand Up @@ -467,9 +466,8 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:

/* If (op3 + 10 > olimit) */
movq %op3, %rax /* rax = op3 */
addq $10, %rax /* rax = op3 + 10 */
cmpq %rax, %olimit /* op3 + 10 > olimit */
jb .L_4X2_exit
cmpq %rax, %olimit /* op3 == olimit */
je .L_4X2_exit

/* If (ip1 < ip0) go to exit */
cmpq %ip0, %ip1
Expand Down Expand Up @@ -537,7 +535,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
pop %rax /* oend1 */
pop %rax /* oend2 */
pop %rax /* oend3 */
pop %rax /* ilimit */
pop %rax /* ilowest */
pop %rax /* olimit */
pop %rax /* arg */

Expand Down