Skip to content

Commit

Permalink
Merge pull request facebook#2771 from facebook/opt_investigation
Browse files Browse the repository at this point in the history
Improve optimal parser performance on small data
  • Loading branch information
Cyan4973 authored Sep 14, 2021
2 parents d22bbed + fd94b9d commit 2e6f5bc
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 246 deletions.
121 changes: 64 additions & 57 deletions lib/compress/zstd_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


#define ZSTD_LITFREQ_ADD 2 /* scaling factor for litFreq, so that frequencies adapt faster to new stats */
#define ZSTD_FREQ_DIV 4 /* log factor when using previous stats to init next stats */
#define ZSTD_MAX_PRICE (1<<30)

#define ZSTD_PREDEF_THRESHOLD 1024 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */
Expand All @@ -24,11 +23,11 @@
* Price functions for optimal parser
***************************************/

#if 0 /* approximation at bit level */
#if 0 /* approximation at bit level (for tests) */
# define BITCOST_ACCURACY 0
# define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
# define WEIGHT(stat) ((void)opt, ZSTD_bitWeight(stat))
#elif 0 /* fractional bit accuracy */
# define WEIGHT(stat, opt) ((void)opt, ZSTD_bitWeight(stat))
#elif 0 /* fractional bit accuracy (for tests) */
# define BITCOST_ACCURACY 8
# define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
# define WEIGHT(stat,opt) ((void)opt, ZSTD_fracWeight(stat))
Expand Down Expand Up @@ -79,25 +78,46 @@ static void ZSTD_setBasePrices(optState_t* optPtr, int optLevel)
}


/* ZSTD_downscaleStat() :
* reduce all elements in table by a factor 2^(ZSTD_FREQ_DIV+malus)
* return the resulting sum of elements */
static U32 ZSTD_downscaleStat(unsigned* table, U32 lastEltIndex, int malus)
/* sum_u32() :
 * add up the nbElts first values of table.
 * @return : the total, accumulated in a U32 (0 when nbElts == 0) */
static U32 sum_u32(const unsigned table[], size_t nbElts)
{
    U32 acc = 0;
    size_t remaining;
    /* walk the table backward : unsigned addition gives the same total in any order */
    for (remaining = nbElts; remaining > 0; remaining--) {
        acc += table[remaining - 1];
    }
    return acc;
}

/* ZSTD_downscaleStats() :
 * replace each of the (lastEltIndex+1) first elements of table
 * by 1 + (element >> shift), so that no element ends up at zero.
 * @return : the sum of the updated elements */
static U32 ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift)
{
    U32 s, sum=0;
    DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", (unsigned)lastEltIndex+1, (unsigned)shift);
    assert(shift < 30);
    for (s=0; s<lastEltIndex+1; s++) {
        /* the +1 keeps every symbol with a non-zero frequency after the shift */
        table[s] = 1 + (table[s] >> shift);
        sum += table[s];
    }
    return sum;
}

/* ZSTD_scaleStats() :
 * reduce all elements in table if their sum is too large.
 * The current sum is shifted right by logTarget :
 * when the result is <= 1, the table is left untouched;
 * otherwise the table goes through ZSTD_downscaleStats(),
 * using ZSTD_highbit32() of that result as shift.
 * @return : the resulting sum of elements */
static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget)
{
    U32 const currentSum = sum_u32(table, lastEltIndex+1);
    U32 const ratio = currentSum >> logTarget;
    DEBUGLOG(5, "ZSTD_scaleStats (nbElts=%u, target=%u)", (unsigned)lastEltIndex+1, (unsigned)logTarget);
    assert(logTarget < 30);
    if (ratio > 1) {
        /* sum exceeds the target : shrink the table */
        return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(ratio));
    }
    /* sum is close enough to the target : keep stats as they are */
    return currentSum;
}

/* ZSTD_rescaleFreqs() :
* if first block (detected by optPtr->litLengthSum == 0) : init statistics
* take hints from dictionary if there is one
* or init from zero, using src for literals stats, or flat 1 for match symbols
* and init from zero if there is none,
* using src for literals stats, and baseline stats for sequence symbols
* otherwise downscale existing stats, to be used as seed for next block.
*/
static void
Expand Down Expand Up @@ -174,36 +194,44 @@ ZSTD_rescaleFreqs(optState_t* const optPtr,
if (compressedLiterals) {
unsigned lit = MaxLit;
HIST_count_simple(optPtr->litFreq, &lit, src, srcSize); /* use raw first block to init statistics */
optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8);
}

{ unsigned ll;
for (ll=0; ll<=MaxLL; ll++)
optPtr->litLengthFreq[ll] = 1;
{ unsigned const baseLLfreqs[MaxLL+1] = {
4, 2, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1
};
ZSTD_memcpy(optPtr->litLengthFreq, baseLLfreqs, sizeof(baseLLfreqs)); optPtr->litLengthSum = sum_u32(baseLLfreqs, MaxLL+1);
}
optPtr->litLengthSum = MaxLL+1;

{ unsigned ml;
for (ml=0; ml<=MaxML; ml++)
optPtr->matchLengthFreq[ml] = 1;
}
optPtr->matchLengthSum = MaxML+1;

{ unsigned of;
for (of=0; of<=MaxOff; of++)
optPtr->offCodeFreq[of] = 1;
{ unsigned const baseOFCfreqs[MaxOff+1] = {
6, 2, 1, 1, 2, 3, 4, 4,
4, 3, 2, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1
};
ZSTD_memcpy(optPtr->offCodeFreq, baseOFCfreqs, sizeof(baseOFCfreqs)); optPtr->offCodeSum = sum_u32(baseOFCfreqs, MaxOff+1);
}
optPtr->offCodeSum = MaxOff+1;


}

} else { /* new block : re-use previous statistics, scaled down */

if (compressedLiterals)
optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
optPtr->litLengthSum = ZSTD_downscaleStat(optPtr->litLengthFreq, MaxLL, 0);
optPtr->matchLengthSum = ZSTD_downscaleStat(optPtr->matchLengthFreq, MaxML, 0);
optPtr->offCodeSum = ZSTD_downscaleStat(optPtr->offCodeFreq, MaxOff, 0);
optPtr->litSum = ZSTD_scaleStats(optPtr->litFreq, MaxLit, 12);
optPtr->litLengthSum = ZSTD_scaleStats(optPtr->litLengthFreq, MaxLL, 11);
optPtr->matchLengthSum = ZSTD_scaleStats(optPtr->matchLengthFreq, MaxML, 11);
optPtr->offCodeSum = ZSTD_scaleStats(optPtr->offCodeFreq, MaxOff, 11);
}

ZSTD_setBasePrices(optPtr, optLevel);
Expand Down Expand Up @@ -901,11 +929,11 @@ static void ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, ZSTD_match_
ZSTD_optLdm_maybeAddMatch(matches, nbMatches, optLdm, currPosInBlock);
}


/*-*******************************
* Optimal parser
*********************************/


static U32 ZSTD_totalLen(ZSTD_optimal_t sol)
{
return sol.litlen + sol.mlen;
Expand Down Expand Up @@ -987,7 +1015,7 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
* in every price. We include the literal length to avoid negative
* prices when we subtract the previous literal length.
*/
opt[0].price = ZSTD_litLengthPrice(litlen, optStatePtr, optLevel);
opt[0].price = (int)ZSTD_litLengthPrice(litlen, optStatePtr, optLevel);

/* large match -> immediate encoding */
{ U32 const maxML = matches[nbMatches-1].len;
Expand All @@ -1007,7 +1035,8 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
} }

/* set prices for first matches starting position == 0 */
{ U32 const literalsPrice = opt[0].price + ZSTD_litLengthPrice(0, optStatePtr, optLevel);
assert(opt[0].price >= 0);
{ U32 const literalsPrice = (U32)opt[0].price + ZSTD_litLengthPrice(0, optStatePtr, optLevel);
U32 pos;
U32 matchNb;
for (pos = 1; pos < minMatch; pos++) {
Expand All @@ -1024,7 +1053,7 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
opt[pos].mlen = pos;
opt[pos].off = offset;
opt[pos].litlen = litlen;
opt[pos].price = sequencePrice;
opt[pos].price = (int)sequencePrice;
} }
last_pos = pos-1;
}
Expand All @@ -1039,9 +1068,9 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
/* Fix current position with one literal if cheaper */
{ U32 const litlen = (opt[cur-1].mlen == 0) ? opt[cur-1].litlen + 1 : 1;
int const price = opt[cur-1].price
+ ZSTD_rawLiteralsCost(ip+cur-1, 1, optStatePtr, optLevel)
+ ZSTD_litLengthPrice(litlen, optStatePtr, optLevel)
- ZSTD_litLengthPrice(litlen-1, optStatePtr, optLevel);
+ (int)ZSTD_rawLiteralsCost(ip+cur-1, 1, optStatePtr, optLevel)
+ (int)ZSTD_litLengthPrice(litlen, optStatePtr, optLevel)
- (int)ZSTD_litLengthPrice(litlen-1, optStatePtr, optLevel);
assert(price < 1000000000); /* overflow check */
if (price <= opt[cur].price) {
DEBUGLOG(7, "cPos:%zi==rPos:%u : better price (%.2f<=%.2f) using literal (ll==%u) (hist:%u,%u,%u)",
Expand Down Expand Up @@ -1084,9 +1113,10 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
continue; /* skip unpromising positions; about ~+6% speed, -0.01 ratio */
}

assert(opt[cur].price >= 0);
{ U32 const ll0 = (opt[cur].mlen != 0);
U32 const litlen = (opt[cur].mlen == 0) ? opt[cur].litlen : 0;
U32 const previousPrice = opt[cur].price;
U32 const previousPrice = (U32)opt[cur].price;
U32 const basePrice = previousPrice + ZSTD_litLengthPrice(0, optStatePtr, optLevel);
U32 nbMatches = ZSTD_BtGetAllMatches(matches, ms, &nextToUpdate3, inr, iend, dictMode, opt[cur].rep, ll0, minMatch);
U32 matchNb;
Expand Down Expand Up @@ -1126,7 +1156,7 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,

for (mlen = lastML; mlen >= startML; mlen--) { /* scan downward */
U32 const pos = cur + mlen;
int const price = basePrice + ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel);
int const price = (int)basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel);

if ((pos > last_pos) || (price < opt[pos].price)) {
DEBUGLOG(7, "rPos:%u (ml=%2u) => new better price (%.2f<%.2f)",
Expand Down Expand Up @@ -1222,28 +1252,7 @@ size_t ZSTD_compressBlock_btopt(
}


/* ZSTD_upscaleStat() :
 * used in 2-pass strategy.
 * Each of the (lastEltIndex+1) first elements of table is shifted left
 * by (ZSTD_FREQ_DIV+bonus) bits, then decremented by 1.
 * @return : the sum of the updated elements */
static U32 ZSTD_upscaleStat(unsigned* table, U32 lastEltIndex, int bonus)
{
    int const shift = ZSTD_FREQ_DIV + bonus;
    U32 const nbElts = lastEltIndex + 1;
    U32 total = 0;
    U32 idx;
    assert(shift >= 0);
    for (idx = 0; idx < nbElts; idx++) {
        unsigned const scaled = (table[idx] << shift) - 1;
        table[idx] = scaled;
        total += scaled;
    }
    return total;
}

/* ZSTD_upscaleStats() :
 * used in 2-pass strategy.
 * Runs ZSTD_upscaleStat() with a bonus of 0 over each frequency table of optPtr
 * (literals only when ZSTD_compressedLiterals() says so),
 * and stores each returned sum in the matching *Sum field. */
MEM_STATIC void ZSTD_upscaleStats(optState_t* optPtr)
{
    int const noBonus = 0;
    if (ZSTD_compressedLiterals(optPtr)) {
        optPtr->litSum = ZSTD_upscaleStat(optPtr->litFreq, MaxLit, noBonus);
    }
    optPtr->litLengthSum = ZSTD_upscaleStat(optPtr->litLengthFreq, MaxLL, noBonus);
    optPtr->matchLengthSum = ZSTD_upscaleStat(optPtr->matchLengthFreq, MaxML, noBonus);
    optPtr->offCodeSum = ZSTD_upscaleStat(optPtr->offCodeFreq, MaxOff, noBonus);
}

/* ZSTD_initStats_ultra():
* make a first compression pass, just to seed stats with more accurate starting values.
Expand Down Expand Up @@ -1274,8 +1283,6 @@ ZSTD_initStats_ultra(ZSTD_matchState_t* ms,
ms->window.lowLimit = ms->window.dictLimit;
ms->nextToUpdate = ms->window.dictLimit;

/* re-inforce weight of collected statistics */
ZSTD_upscaleStats(&ms->opt);
}

size_t ZSTD_compressBlock_btultra(
Expand Down
4 changes: 2 additions & 2 deletions lib/dictBuilder/cover.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
* Console display
***************************************/
#ifndef LOCALDISPLAYLEVEL
static int g_displayLevel = 2;
static int g_displayLevel = 0;
#endif
#undef DISPLAY
#define DISPLAY(...) \
Expand Down Expand Up @@ -735,7 +735,7 @@ ZDICTLIB_API size_t ZDICT_trainFromBuffer_cover(
COVER_map_t activeDmers;
parameters.splitPoint = 1.0;
/* Initialize global data */
g_displayLevel = parameters.zParams.notificationLevel;
g_displayLevel = (int)parameters.zParams.notificationLevel;
/* Checks */
if (!COVER_checkParameters(parameters, dictBufferCapacity)) {
DISPLAYLEVEL(1, "Cover parameters incorrect\n");
Expand Down
8 changes: 4 additions & 4 deletions lib/dictBuilder/fastcover.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* Console display
***************************************/
#ifndef LOCALDISPLAYLEVEL
static int g_displayLevel = 2;
static int g_displayLevel = 0;
#endif
#undef DISPLAY
#define DISPLAY(...) \
Expand Down Expand Up @@ -549,7 +549,7 @@ ZDICT_trainFromBuffer_fastCover(void* dictBuffer, size_t dictBufferCapacity,
ZDICT_cover_params_t coverParams;
FASTCOVER_accel_t accelParams;
/* Initialize global data */
g_displayLevel = parameters.zParams.notificationLevel;
g_displayLevel = (int)parameters.zParams.notificationLevel;
/* Assign splitPoint and f if not provided */
parameters.splitPoint = 1.0;
parameters.f = parameters.f == 0 ? DEFAULT_F : parameters.f;
Expand Down Expand Up @@ -632,7 +632,7 @@ ZDICT_optimizeTrainFromBuffer_fastCover(
const unsigned accel = parameters->accel == 0 ? DEFAULT_ACCEL : parameters->accel;
const unsigned shrinkDict = 0;
/* Local variables */
const int displayLevel = parameters->zParams.notificationLevel;
const int displayLevel = (int)parameters->zParams.notificationLevel;
unsigned iteration = 1;
unsigned d;
unsigned k;
Expand Down Expand Up @@ -716,7 +716,7 @@ ZDICT_optimizeTrainFromBuffer_fastCover(
data->parameters.splitPoint = splitPoint;
data->parameters.steps = kSteps;
data->parameters.shrinkDict = shrinkDict;
data->parameters.zParams.notificationLevel = g_displayLevel;
data->parameters.zParams.notificationLevel = (unsigned)g_displayLevel;
/* Check the parameters */
if (!FASTCOVER_checkParameters(data->parameters, dictBufferCapacity,
data->ctx->f, accel)) {
Expand Down
25 changes: 16 additions & 9 deletions lib/dictBuilder/zdict.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ static unsigned ZDICT_NbCommonBytes (size_t val)
_BitScanForward64( &r, (U64)val );
return (unsigned)(r>>3);
# elif defined(__GNUC__) && (__GNUC__ >= 3)
return (__builtin_ctzll((U64)val) >> 3);
return (unsigned)(__builtin_ctzll((U64)val) >> 3);
# else
static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, 0, 3, 1, 3, 1, 4, 2, 7, 0, 2, 3, 6, 1, 5, 3, 5, 1, 3, 4, 4, 2, 5, 6, 7, 7, 0, 1, 2, 3, 3, 4, 6, 2, 6, 5, 5, 3, 4, 5, 6, 7, 1, 2, 4, 6, 4, 4, 5, 7, 2, 6, 5, 7, 6, 7, 7 };
return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58];
Expand All @@ -150,7 +150,7 @@ static unsigned ZDICT_NbCommonBytes (size_t val)
_BitScanForward( &r, (U32)val );
return (unsigned)(r>>3);
# elif defined(__GNUC__) && (__GNUC__ >= 3)
return (__builtin_ctz((U32)val) >> 3);
return (unsigned)(__builtin_ctz((U32)val) >> 3);
# else
static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, 3, 2, 2, 1, 3, 2, 0, 1, 3, 3, 1, 2, 2, 2, 2, 0, 3, 1, 2, 0, 1, 0, 1, 1 };
return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27];
Expand All @@ -163,7 +163,7 @@ static unsigned ZDICT_NbCommonBytes (size_t val)
_BitScanReverse64( &r, val );
return (unsigned)(r>>3);
# elif defined(__GNUC__) && (__GNUC__ >= 3)
return (__builtin_clzll(val) >> 3);
return (unsigned)(__builtin_clzll(val) >> 3);
# else
unsigned r;
const unsigned n32 = sizeof(size_t)*4; /* calculate this way due to compiler complaining in 32-bits mode */
Expand All @@ -178,7 +178,7 @@ static unsigned ZDICT_NbCommonBytes (size_t val)
_BitScanReverse( &r, (unsigned long)val );
return (unsigned)(r>>3);
# elif defined(__GNUC__) && (__GNUC__ >= 3)
return (__builtin_clz((U32)val) >> 3);
return (unsigned)(__builtin_clz((U32)val) >> 3);
# else
unsigned r;
if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; }
Expand Down Expand Up @@ -235,7 +235,7 @@ static dictItem ZDICT_analyzePos(
U32 savings[LLIMIT] = {0};
const BYTE* b = (const BYTE*)buffer;
size_t maxLength = LLIMIT;
size_t pos = suffix[start];
size_t pos = (size_t)suffix[start];
U32 end = start;
dictItem solution;

Expand Down Expand Up @@ -369,7 +369,7 @@ static dictItem ZDICT_analyzePos(
savings[i] = savings[i-1] + (lengthList[i] * (i-3));

DISPLAYLEVEL(4, "Selected dict at position %u, of length %u : saves %u (ratio: %.2f) \n",
(unsigned)pos, (unsigned)maxLength, (unsigned)savings[maxLength], (double)savings[maxLength] / maxLength);
(unsigned)pos, (unsigned)maxLength, (unsigned)savings[maxLength], (double)savings[maxLength] / (double)maxLength);

solution.pos = (U32)pos;
solution.length = (U32)maxLength;
Expand All @@ -379,7 +379,7 @@ static dictItem ZDICT_analyzePos(
{ U32 id;
for (id=start; id<end; id++) {
U32 p, pEnd, length;
U32 const testedPos = suffix[id];
U32 const testedPos = (U32)suffix[id];
if (testedPos == pos)
length = solution.length;
else {
Expand Down Expand Up @@ -442,7 +442,7 @@ static U32 ZDICT_tryMerge(dictItem* table, dictItem elt, U32 eltNbToSkip, const

if ((table[u].pos + table[u].length >= elt.pos) && (table[u].pos < elt.pos)) { /* overlap, existing < new */
/* append */
int const addedLength = (int)eltEnd - (table[u].pos + table[u].length);
int const addedLength = (int)eltEnd - (int)(table[u].pos + table[u].length);
table[u].savings += elt.length / 8; /* rough approx bonus */
if (addedLength > 0) { /* otherwise, elt fully included into existing */
table[u].length += addedLength;
Expand Down Expand Up @@ -766,6 +766,13 @@ static size_t ZDICT_analyzeEntropy(void* dstBuffer, size_t maxDstSize,
pos += fileSizes[u];
}

if (notificationLevel >= 4) {
/* writeStats */
DISPLAYLEVEL(4, "Offset Code Frequencies : \n");
for (u=0; u<=offcodeMax; u++) {
DISPLAYLEVEL(4, "%2u :%7u \n", u, offcodeCount[u]);
} }

/* analyze, build stats, starting with literals */
{ size_t maxNbBits = HUF_buildCTable (hufTable, countLit, 255, huffLog);
if (HUF_isError(maxNbBits)) {
Expand Down Expand Up @@ -872,7 +879,7 @@ static size_t ZDICT_analyzeEntropy(void* dstBuffer, size_t maxDstSize,
MEM_writeLE32(dstPtr+8, bestRepOffset[2].offset);
#else
/* at this stage, we don't use the result of "most common first offset",
as the impact of statistics is not properly evaluated */
* as the impact of statistics is not properly evaluated */
MEM_writeLE32(dstPtr+0, repStartValue[0]);
MEM_writeLE32(dstPtr+4, repStartValue[1]);
MEM_writeLE32(dstPtr+8, repStartValue[2]);
Expand Down
Loading

0 comments on commit 2e6f5bc

Please sign in to comment.