Skip to content

Commit

Permalink
zstd decode: Use compound decSymbol (#144)
Browse files Browse the repository at this point in the history
* zstd decode: Use compound decSymbol

The compiler is not generating very nice code when decSymbol is separate values.
Use a compound value instead of a struct. This will allow all values to be transferred in a single register.

```
name                                     old time/op    new time/op    delta
Decoder_DecodeAll/kppkn.gtb.zst-12          610µs ± 1%     599µs ± 1%  -1.84%        (p=0.000 n=10+10)
Decoder_DecodeAll/geo.protodata.zst-12      138µs ± 2%     136µs ± 0%  -1.72%         (p=0.000 n=10+9)
Decoder_DecodeAll/plrabn12.txt.zst-12      1.95ms ± 1%    1.90ms ± 0%  -2.19%         (p=0.000 n=10+7)
Decoder_DecodeAll/lcet10.txt.zst-12        1.46ms ± 2%    1.42ms ± 1%  -2.48%        (p=0.000 n=10+10)
Decoder_DecodeAll/asyoulik.txt.zst-12       506µs ± 1%     497µs ± 1%  -1.83%         (p=0.000 n=10+9)
Decoder_DecodeAll/alice29.txt.zst-12        655µs ± 1%     636µs ± 1%  -2.97%        (p=0.000 n=10+10)
Decoder_DecodeAll/html_x_4.zst-12           267µs ± 1%     261µs ± 1%  -2.43%         (p=0.000 n=10+9)
Decoder_DecodeAll/paper-100k.pdf.zst-12    25.0µs ± 1%    24.3µs ± 1%  -2.61%        (p=0.000 n=10+10)
Decoder_DecodeAll/fireworks.jpeg.zst-12    9.65µs ± 1%    9.61µs ± 0%    ~            (p=0.250 n=10+9)
Decoder_DecodeAll/urls.10K.zst-12          1.67ms ± 1%    1.63ms ± 2%  -2.29%        (p=0.000 n=10+10)
Decoder_DecodeAll/html.zst-12               156µs ± 1%     155µs ± 1%  -0.85%        (p=0.014 n=10+10)

name                                     old speed      new speed      delta
Decoder_DecodeAll/kppkn.gtb.zst-12        302MB/s ± 1%   308MB/s ± 1%  +1.88%        (p=0.000 n=10+10)
Decoder_DecodeAll/geo.protodata.zst-12    860MB/s ± 2%   875MB/s ± 0%  +1.75%         (p=0.000 n=10+9)
Decoder_DecodeAll/plrabn12.txt.zst-12     248MB/s ± 1%   253MB/s ± 0%  +2.24%         (p=0.000 n=10+7)
Decoder_DecodeAll/lcet10.txt.zst-12       293MB/s ± 2%   300MB/s ± 1%  +2.54%        (p=0.000 n=10+10)
Decoder_DecodeAll/asyoulik.txt.zst-12     247MB/s ± 1%   252MB/s ± 1%  +1.86%         (p=0.000 n=10+9)
Decoder_DecodeAll/alice29.txt.zst-12      232MB/s ± 1%   239MB/s ± 1%  +3.06%        (p=0.000 n=10+10)
Decoder_DecodeAll/html_x_4.zst-12        1.53GB/s ± 1%  1.57GB/s ± 1%  +2.49%         (p=0.000 n=10+9)
Decoder_DecodeAll/paper-100k.pdf.zst-12  4.10GB/s ± 1%  4.21GB/s ± 1%  +2.68%        (p=0.000 n=10+10)
Decoder_DecodeAll/fireworks.jpeg.zst-12  12.8GB/s ± 1%  12.8GB/s ± 0%    ~            (p=0.286 n=10+9)
Decoder_DecodeAll/urls.10K.zst-12         420MB/s ± 1%   430MB/s ± 2%  +2.35%        (p=0.000 n=10+10)
Decoder_DecodeAll/html.zst-12             655MB/s ± 1%   661MB/s ± 1%  +0.86%        (p=0.015 n=10+10)
```
  • Loading branch information
klauspost committed Aug 3, 2019
1 parent 1a36bca commit 0e54620
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 168 deletions.
167 changes: 87 additions & 80 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,103 +966,110 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) {
// We don't predefine them, since this also tests our transformations.
// Reference from here: https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L234
func TestPredefTables(t *testing.T) {
x := func(nextState uint16, nbAddBits, nbBits uint8, baseVal uint32) decSymbol {
return newDecSymbol(nbBits, nbAddBits, nextState, baseVal)
}
for i := range fsePredef[:] {
var want []decSymbol
switch tableIndex(i) {
case tableLiteralLengths:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 4, 0}, {16, 0, 4, 0},
{32, 0, 5, 1}, {0, 0, 5, 3},
{0, 0, 5, 4}, {0, 0, 5, 6},
{0, 0, 5, 7}, {0, 0, 5, 9},
{0, 0, 5, 10}, {0, 0, 5, 12},
{0, 0, 6, 14}, {0, 1, 5, 16},
{0, 1, 5, 20}, {0, 1, 5, 22},
{0, 2, 5, 28}, {0, 3, 5, 32},
{0, 4, 5, 48}, {32, 6, 5, 64},
{0, 7, 5, 128}, {0, 8, 6, 256},
{0, 10, 6, 1024}, {0, 12, 6, 4096},
{32, 0, 4, 0}, {0, 0, 4, 1},
{0, 0, 5, 2}, {32, 0, 5, 4},
{0, 0, 5, 5}, {32, 0, 5, 7},
{0, 0, 5, 8}, {32, 0, 5, 10},
{0, 0, 5, 11}, {0, 0, 6, 13},
{32, 1, 5, 16}, {0, 1, 5, 18},
{32, 1, 5, 22}, {0, 2, 5, 24},
{32, 3, 5, 32}, {0, 3, 5, 40},
{0, 6, 4, 64}, {16, 6, 4, 64},
{32, 7, 5, 128}, {0, 9, 6, 512},
{0, 11, 6, 2048}, {48, 0, 4, 0},
{16, 0, 4, 1}, {32, 0, 5, 2},
{32, 0, 5, 3}, {32, 0, 5, 5},
{32, 0, 5, 6}, {32, 0, 5, 8},
{32, 0, 5, 9}, {32, 0, 5, 11},
{32, 0, 5, 12}, {0, 0, 6, 15},
{32, 1, 5, 18}, {32, 1, 5, 20},
{32, 2, 5, 24}, {32, 2, 5, 28},
{32, 3, 5, 40}, {32, 4, 5, 48},
{0, 16, 6, 65536}, {0, 15, 6, 32768},
{0, 14, 6, 16384}, {0, 13, 6, 8192}}
x(0, 0, 4, 0), x(16, 0, 4, 0),
x(32, 0, 5, 1), x(0, 0, 5, 3),
x(0, 0, 5, 4), x(0, 0, 5, 6),
x(0, 0, 5, 7), x(0, 0, 5, 9),
x(0, 0, 5, 10), x(0, 0, 5, 12),
x(0, 0, 6, 14), x(0, 1, 5, 16),
x(0, 1, 5, 20), x(0, 1, 5, 22),
x(0, 2, 5, 28), x(0, 3, 5, 32),
x(0, 4, 5, 48), x(32, 6, 5, 64),
x(0, 7, 5, 128), x(0, 8, 6, 256),
x(0, 10, 6, 1024), x(0, 12, 6, 4096),
x(32, 0, 4, 0), x(0, 0, 4, 1),
x(0, 0, 5, 2), x(32, 0, 5, 4),
x(0, 0, 5, 5), x(32, 0, 5, 7),
x(0, 0, 5, 8), x(32, 0, 5, 10),
x(0, 0, 5, 11), x(0, 0, 6, 13),
x(32, 1, 5, 16), x(0, 1, 5, 18),
x(32, 1, 5, 22), x(0, 2, 5, 24),
x(32, 3, 5, 32), x(0, 3, 5, 40),
x(0, 6, 4, 64), x(16, 6, 4, 64),
x(32, 7, 5, 128), x(0, 9, 6, 512),
x(0, 11, 6, 2048), x(48, 0, 4, 0),
x(16, 0, 4, 1), x(32, 0, 5, 2),
x(32, 0, 5, 3), x(32, 0, 5, 5),
x(32, 0, 5, 6), x(32, 0, 5, 8),
x(32, 0, 5, 9), x(32, 0, 5, 11),
x(32, 0, 5, 12), x(0, 0, 6, 15),
x(32, 1, 5, 18), x(32, 1, 5, 20),
x(32, 2, 5, 24), x(32, 2, 5, 28),
x(32, 3, 5, 40), x(32, 4, 5, 48),
x(0, 16, 6, 65536), x(0, 15, 6, 32768),
x(0, 14, 6, 16384), x(0, 13, 6, 8192),
}
case tableOffsets:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 5, 0}, {0, 6, 4, 61},
{0, 9, 5, 509}, {0, 15, 5, 32765},
{0, 21, 5, 2097149}, {0, 3, 5, 5},
{0, 7, 4, 125}, {0, 12, 5, 4093},
{0, 18, 5, 262141}, {0, 23, 5, 8388605},
{0, 5, 5, 29}, {0, 8, 4, 253},
{0, 14, 5, 16381}, {0, 20, 5, 1048573},
{0, 2, 5, 1}, {16, 7, 4, 125},
{0, 11, 5, 2045}, {0, 17, 5, 131069},
{0, 22, 5, 4194301}, {0, 4, 5, 13},
{16, 8, 4, 253}, {0, 13, 5, 8189},
{0, 19, 5, 524285}, {0, 1, 5, 1},
{16, 6, 4, 61}, {0, 10, 5, 1021},
{0, 16, 5, 65533}, {0, 28, 5, 268435453},
{0, 27, 5, 134217725}, {0, 26, 5, 67108861},
{0, 25, 5, 33554429}, {0, 24, 5, 16777213}}
x(0, 0, 5, 0), x(0, 6, 4, 61),
x(0, 9, 5, 509), x(0, 15, 5, 32765),
x(0, 21, 5, 2097149), x(0, 3, 5, 5),
x(0, 7, 4, 125), x(0, 12, 5, 4093),
x(0, 18, 5, 262141), x(0, 23, 5, 8388605),
x(0, 5, 5, 29), x(0, 8, 4, 253),
x(0, 14, 5, 16381), x(0, 20, 5, 1048573),
x(0, 2, 5, 1), x(16, 7, 4, 125),
x(0, 11, 5, 2045), x(0, 17, 5, 131069),
x(0, 22, 5, 4194301), x(0, 4, 5, 13),
x(16, 8, 4, 253), x(0, 13, 5, 8189),
x(0, 19, 5, 524285), x(0, 1, 5, 1),
x(16, 6, 4, 61), x(0, 10, 5, 1021),
x(0, 16, 5, 65533), x(0, 28, 5, 268435453),
x(0, 27, 5, 134217725), x(0, 26, 5, 67108861),
x(0, 25, 5, 33554429), x(0, 24, 5, 16777213),
}
case tableMatchLengths:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 6, 3}, {0, 0, 4, 4},
{32, 0, 5, 5}, {0, 0, 5, 6},
{0, 0, 5, 8}, {0, 0, 5, 9},
{0, 0, 5, 11}, {0, 0, 6, 13},
{0, 0, 6, 16}, {0, 0, 6, 19},
{0, 0, 6, 22}, {0, 0, 6, 25},
{0, 0, 6, 28}, {0, 0, 6, 31},
{0, 0, 6, 34}, {0, 1, 6, 37},
{0, 1, 6, 41}, {0, 2, 6, 47},
{0, 3, 6, 59}, {0, 4, 6, 83},
{0, 7, 6, 131}, {0, 9, 6, 515},
{16, 0, 4, 4}, {0, 0, 4, 5},
{32, 0, 5, 6}, {0, 0, 5, 7},
{32, 0, 5, 9}, {0, 0, 5, 10},
{0, 0, 6, 12}, {0, 0, 6, 15},
{0, 0, 6, 18}, {0, 0, 6, 21},
{0, 0, 6, 24}, {0, 0, 6, 27},
{0, 0, 6, 30}, {0, 0, 6, 33},
{0, 1, 6, 35}, {0, 1, 6, 39},
{0, 2, 6, 43}, {0, 3, 6, 51},
{0, 4, 6, 67}, {0, 5, 6, 99},
{0, 8, 6, 259}, {32, 0, 4, 4},
{48, 0, 4, 4}, {16, 0, 4, 5},
{32, 0, 5, 7}, {32, 0, 5, 8},
{32, 0, 5, 10}, {32, 0, 5, 11},
{0, 0, 6, 14}, {0, 0, 6, 17},
{0, 0, 6, 20}, {0, 0, 6, 23},
{0, 0, 6, 26}, {0, 0, 6, 29},
{0, 0, 6, 32}, {0, 16, 6, 65539},
{0, 15, 6, 32771}, {0, 14, 6, 16387},
{0, 13, 6, 8195}, {0, 12, 6, 4099},
{0, 11, 6, 2051}, {0, 10, 6, 1027},
x(0, 0, 6, 3), x(0, 0, 4, 4),
x(32, 0, 5, 5), x(0, 0, 5, 6),
x(0, 0, 5, 8), x(0, 0, 5, 9),
x(0, 0, 5, 11), x(0, 0, 6, 13),
x(0, 0, 6, 16), x(0, 0, 6, 19),
x(0, 0, 6, 22), x(0, 0, 6, 25),
x(0, 0, 6, 28), x(0, 0, 6, 31),
x(0, 0, 6, 34), x(0, 1, 6, 37),
x(0, 1, 6, 41), x(0, 2, 6, 47),
x(0, 3, 6, 59), x(0, 4, 6, 83),
x(0, 7, 6, 131), x(0, 9, 6, 515),
x(16, 0, 4, 4), x(0, 0, 4, 5),
x(32, 0, 5, 6), x(0, 0, 5, 7),
x(32, 0, 5, 9), x(0, 0, 5, 10),
x(0, 0, 6, 12), x(0, 0, 6, 15),
x(0, 0, 6, 18), x(0, 0, 6, 21),
x(0, 0, 6, 24), x(0, 0, 6, 27),
x(0, 0, 6, 30), x(0, 0, 6, 33),
x(0, 1, 6, 35), x(0, 1, 6, 39),
x(0, 2, 6, 43), x(0, 3, 6, 51),
x(0, 4, 6, 67), x(0, 5, 6, 99),
x(0, 8, 6, 259), x(32, 0, 4, 4),
x(48, 0, 4, 4), x(16, 0, 4, 5),
x(32, 0, 5, 7), x(32, 0, 5, 8),
x(32, 0, 5, 10), x(32, 0, 5, 11),
x(0, 0, 6, 14), x(0, 0, 6, 17),
x(0, 0, 6, 20), x(0, 0, 6, 23),
x(0, 0, 6, 26), x(0, 0, 6, 29),
x(0, 0, 6, 32), x(0, 16, 6, 65539),
x(0, 15, 6, 32771), x(0, 14, 6, 16387),
x(0, 13, 6, 8195), x(0, 12, 6, 4099),
x(0, 11, 6, 2051), x(0, 10, 6, 1027),
}
}
pre := fsePredef[i]
got := pre.dt[:1<<pre.actualTableLog]
if !reflect.DeepEqual(got, want) {
t.Logf("want: %v", want)
t.Logf("got : %v", got)
t.Errorf("Predefined table %d incorrect, len(got) = %d, len(want) = %d", i, len(got), len(want))
}
}
Expand Down
113 changes: 80 additions & 33 deletions zstd/fse_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,29 +184,75 @@ func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error {
// decSymbol contains information about a state entry,
// including the state offset base, the output symbol and
// the number of bits to read for the low part of the destination state.
type decSymbol struct {
newState uint16
addBits uint8 // Used for symbols until transformed.
nbBits uint8
baseline uint32
// All fields are packed into a single uint64, which is faster than a struct
// with separate members. Layout, counted from the least significant bit:
//
//	bits  0-7:  nbBits
//	bits  8-15: addBits
//	bits 16-31: newState
//	bits 32-63: baseline
type decSymbol uint64

// newDecSymbol packs the four fields into one decSymbol.
func newDecSymbol(nbits, addBits uint8, newState uint16, baseline uint32) decSymbol {
	lo := decSymbol(nbits) | decSymbol(addBits)<<8
	hi := decSymbol(newState)<<16 | decSymbol(baseline)<<32
	return hi | lo
}

// nbBits returns the lowest byte of the packed value.
func (d decSymbol) nbBits() uint8 {
	return uint8(d)
}

// addBits returns the second-lowest byte of the packed value.
func (d decSymbol) addBits() uint8 {
	return uint8(d >> 8)
}

// newState returns the 16 bits stored at bit offset 16.
func (d decSymbol) newState() uint16 {
	return uint16(d >> 16)
}

// baseline returns the upper 32 bits of the packed value.
func (d decSymbol) baseline() uint32 {
	return uint32(d >> 32)
}

// baselineInt returns the upper 32 bits of the packed value converted to int.
func (d decSymbol) baselineInt() int {
	return int(d >> 32)
}

// set overwrites every field of d at once.
func (d *decSymbol) set(nbits, addBits uint8, newState uint16, baseline uint32) {
	*d = newDecSymbol(nbits, addBits, newState, baseline)
}

// setNBits replaces the nbBits field and leaves the other fields untouched.
func (d *decSymbol) setNBits(nBits uint8) {
	*d = *d&^0xff | decSymbol(nBits)
}

// setAddBits replaces the addBits field and leaves the other fields untouched.
func (d *decSymbol) setAddBits(addBits uint8) {
	*d = *d&^(0xff<<8) | decSymbol(addBits)<<8
}

// setNewState replaces the newState field and leaves the other fields untouched.
func (d *decSymbol) setNewState(state uint16) {
	*d = *d&^(0xffff<<16) | decSymbol(state)<<16
}

// setBaseline replaces the baseline field and leaves the low 32 bits untouched.
func (d *decSymbol) setBaseline(baseline uint32) {
	low := decSymbol(uint32(*d))
	*d = low | decSymbol(baseline)<<32
}

// setExt replaces addBits and baseline in one step; nbBits and newState
// are preserved.
func (d *decSymbol) setExt(addBits uint8, baseline uint32) {
	const keep = 0xff | 0xffff<<16 // nbBits and newState
	*d = *d&keep | decSymbol(addBits)<<8 | decSymbol(baseline)<<32
}

// decSymbolValue returns the transformed decSymbol for the given symbol.
func decSymbolValue(symb uint8, t []baseOffset) (decSymbol, error) {
if int(symb) >= len(t) {
return decSymbol{}, fmt.Errorf("rle symbol %d >= max %d", symb, len(t))
return 0, fmt.Errorf("rle symbol %d >= max %d", symb, len(t))
}
lu := t[symb]
return decSymbol{
addBits: lu.addBits,
baseline: lu.baseLine,
}, nil
return newDecSymbol(0, lu.addBits, 0, lu.baseLine), nil
}

// setRLE will set the decoder to RLE mode.
func (s *fseDecoder) setRLE(symbol decSymbol) {
s.actualTableLog = 0
s.maxBits = symbol.addBits
s.maxBits = symbol.addBits()
s.dt[0] = symbol
}

Expand All @@ -220,7 +266,7 @@ func (s *fseDecoder) buildDtable() error {
{
for i, v := range s.norm[:s.symbolLen] {
if v == -1 {
s.dt[highThreshold].addBits = uint8(i)
s.dt[highThreshold].setAddBits(uint8(i))
highThreshold--
symbolNext[i] = 1
} else {
Expand All @@ -235,7 +281,7 @@ func (s *fseDecoder) buildDtable() error {
position := uint32(0)
for ss, v := range s.norm[:s.symbolLen] {
for i := 0; i < int(v); i++ {
s.dt[position].addBits = uint8(ss)
s.dt[position].setAddBits(uint8(ss))
position = (position + step) & tableMask
for position > highThreshold {
// lowprob area
Expand All @@ -253,11 +299,11 @@ func (s *fseDecoder) buildDtable() error {
{
tableSize := uint16(1 << s.actualTableLog)
for u, v := range s.dt[:tableSize] {
symbol := v.addBits
symbol := v.addBits()
nextState := symbolNext[symbol]
symbolNext[symbol] = nextState + 1
nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
s.dt[u&maxTableMask].nbBits = nBits
s.dt[u&maxTableMask].setNBits(nBits)
newState := (nextState << nBits) - tableSize
if newState > tableSize {
return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
Expand All @@ -266,7 +312,7 @@ func (s *fseDecoder) buildDtable() error {
// Seems weird that this is possible with nbits > 0.
return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
}
s.dt[u&maxTableMask].newState = newState
s.dt[u&maxTableMask].setNewState(newState)
}
}
return nil
Expand All @@ -279,25 +325,21 @@ func (s *fseDecoder) transform(t []baseOffset) error {
tableSize := uint16(1 << s.actualTableLog)
s.maxBits = 0
for i, v := range s.dt[:tableSize] {
if int(v.addBits) >= len(t) {
return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits, len(t))
add := v.addBits()
if int(add) >= len(t) {
return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits(), len(t))
}
lu := t[v.addBits]
lu := t[add]
if lu.addBits > s.maxBits {
s.maxBits = lu.addBits
}
s.dt[i&maxTableMask] = decSymbol{
newState: v.newState,
nbBits: v.nbBits,
addBits: lu.addBits,
baseline: lu.baseLine,
}
v.setExt(lu.addBits, lu.baseLine)
s.dt[i] = v
}
return nil
}

type fseState struct {
// TODO: Check if *[1 << maxTablelog]decSymbol is faster.
dt []decSymbol
state decSymbol
}
Expand All @@ -312,26 +354,31 @@ func (s *fseState) init(br *bitReader, tableLog uint8, dt []decSymbol) {
// next returns the current symbol and sets the next state.
// At least tablelog bits must be available in the bit reader.
func (s *fseState) next(br *bitReader) {
lowBits := uint16(br.getBits(s.state.nbBits))
s.state = s.dt[s.state.newState+lowBits]
lowBits := uint16(br.getBits(s.state.nbBits()))
s.state = s.dt[s.state.newState()+lowBits]
}

// finished returns true if all bits have been read from the bitstream
// and the next state would require reading bits from the input.
func (s *fseState) finished(br *bitReader) bool {
return br.finished() && s.state.nbBits > 0
return br.finished() && s.state.nbBits() > 0
}

// final returns the current state symbol without decoding the next.
func (s *fseState) final() (int, uint8) {
return int(s.state.baseline), s.state.addBits
return s.state.baselineInt(), s.state.addBits()
}

// final returns the baseline (as an int) and the addBits value held in the
// symbol itself.
func (d decSymbol) final() (int, uint8) {
	base, extra := d.baselineInt(), d.addBits()
	return base, extra
}

// nextFast returns the next symbol and sets the next state.
// This can only be used if no symbols are 0 bits.
// At least tablelog bits must be available in the bit reader.
func (s *fseState) nextFast(br *bitReader) (uint32, uint8) {
lowBits := uint16(br.getBitsFast(s.state.nbBits))
s.state = s.dt[s.state.newState+lowBits]
return s.state.baseline, s.state.addBits
lowBits := uint16(br.getBitsFast(s.state.nbBits()))
s.state = s.dt[s.state.newState()+lowBits]
return s.state.baseline(), s.state.addBits()
}
Loading

0 comments on commit 0e54620

Please sign in to comment.