Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd decode: Use compound decSymbol #144

Merged
merged 3 commits into from
Aug 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 87 additions & 80 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,103 +966,110 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) {
// We don't predefine them, since this also tests our transformations.
// Reference from here: https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L234
func TestPredefTables(t *testing.T) {
x := func(nextState uint16, nbAddBits, nbBits uint8, baseVal uint32) decSymbol {
return newDecSymbol(nbBits, nbAddBits, nextState, baseVal)
}
for i := range fsePredef[:] {
var want []decSymbol
switch tableIndex(i) {
case tableLiteralLengths:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 4, 0}, {16, 0, 4, 0},
{32, 0, 5, 1}, {0, 0, 5, 3},
{0, 0, 5, 4}, {0, 0, 5, 6},
{0, 0, 5, 7}, {0, 0, 5, 9},
{0, 0, 5, 10}, {0, 0, 5, 12},
{0, 0, 6, 14}, {0, 1, 5, 16},
{0, 1, 5, 20}, {0, 1, 5, 22},
{0, 2, 5, 28}, {0, 3, 5, 32},
{0, 4, 5, 48}, {32, 6, 5, 64},
{0, 7, 5, 128}, {0, 8, 6, 256},
{0, 10, 6, 1024}, {0, 12, 6, 4096},
{32, 0, 4, 0}, {0, 0, 4, 1},
{0, 0, 5, 2}, {32, 0, 5, 4},
{0, 0, 5, 5}, {32, 0, 5, 7},
{0, 0, 5, 8}, {32, 0, 5, 10},
{0, 0, 5, 11}, {0, 0, 6, 13},
{32, 1, 5, 16}, {0, 1, 5, 18},
{32, 1, 5, 22}, {0, 2, 5, 24},
{32, 3, 5, 32}, {0, 3, 5, 40},
{0, 6, 4, 64}, {16, 6, 4, 64},
{32, 7, 5, 128}, {0, 9, 6, 512},
{0, 11, 6, 2048}, {48, 0, 4, 0},
{16, 0, 4, 1}, {32, 0, 5, 2},
{32, 0, 5, 3}, {32, 0, 5, 5},
{32, 0, 5, 6}, {32, 0, 5, 8},
{32, 0, 5, 9}, {32, 0, 5, 11},
{32, 0, 5, 12}, {0, 0, 6, 15},
{32, 1, 5, 18}, {32, 1, 5, 20},
{32, 2, 5, 24}, {32, 2, 5, 28},
{32, 3, 5, 40}, {32, 4, 5, 48},
{0, 16, 6, 65536}, {0, 15, 6, 32768},
{0, 14, 6, 16384}, {0, 13, 6, 8192}}
x(0, 0, 4, 0), x(16, 0, 4, 0),
x(32, 0, 5, 1), x(0, 0, 5, 3),
x(0, 0, 5, 4), x(0, 0, 5, 6),
x(0, 0, 5, 7), x(0, 0, 5, 9),
x(0, 0, 5, 10), x(0, 0, 5, 12),
x(0, 0, 6, 14), x(0, 1, 5, 16),
x(0, 1, 5, 20), x(0, 1, 5, 22),
x(0, 2, 5, 28), x(0, 3, 5, 32),
x(0, 4, 5, 48), x(32, 6, 5, 64),
x(0, 7, 5, 128), x(0, 8, 6, 256),
x(0, 10, 6, 1024), x(0, 12, 6, 4096),
x(32, 0, 4, 0), x(0, 0, 4, 1),
x(0, 0, 5, 2), x(32, 0, 5, 4),
x(0, 0, 5, 5), x(32, 0, 5, 7),
x(0, 0, 5, 8), x(32, 0, 5, 10),
x(0, 0, 5, 11), x(0, 0, 6, 13),
x(32, 1, 5, 16), x(0, 1, 5, 18),
x(32, 1, 5, 22), x(0, 2, 5, 24),
x(32, 3, 5, 32), x(0, 3, 5, 40),
x(0, 6, 4, 64), x(16, 6, 4, 64),
x(32, 7, 5, 128), x(0, 9, 6, 512),
x(0, 11, 6, 2048), x(48, 0, 4, 0),
x(16, 0, 4, 1), x(32, 0, 5, 2),
x(32, 0, 5, 3), x(32, 0, 5, 5),
x(32, 0, 5, 6), x(32, 0, 5, 8),
x(32, 0, 5, 9), x(32, 0, 5, 11),
x(32, 0, 5, 12), x(0, 0, 6, 15),
x(32, 1, 5, 18), x(32, 1, 5, 20),
x(32, 2, 5, 24), x(32, 2, 5, 28),
x(32, 3, 5, 40), x(32, 4, 5, 48),
x(0, 16, 6, 65536), x(0, 15, 6, 32768),
x(0, 14, 6, 16384), x(0, 13, 6, 8192),
}
case tableOffsets:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 5, 0}, {0, 6, 4, 61},
{0, 9, 5, 509}, {0, 15, 5, 32765},
{0, 21, 5, 2097149}, {0, 3, 5, 5},
{0, 7, 4, 125}, {0, 12, 5, 4093},
{0, 18, 5, 262141}, {0, 23, 5, 8388605},
{0, 5, 5, 29}, {0, 8, 4, 253},
{0, 14, 5, 16381}, {0, 20, 5, 1048573},
{0, 2, 5, 1}, {16, 7, 4, 125},
{0, 11, 5, 2045}, {0, 17, 5, 131069},
{0, 22, 5, 4194301}, {0, 4, 5, 13},
{16, 8, 4, 253}, {0, 13, 5, 8189},
{0, 19, 5, 524285}, {0, 1, 5, 1},
{16, 6, 4, 61}, {0, 10, 5, 1021},
{0, 16, 5, 65533}, {0, 28, 5, 268435453},
{0, 27, 5, 134217725}, {0, 26, 5, 67108861},
{0, 25, 5, 33554429}, {0, 24, 5, 16777213}}
x(0, 0, 5, 0), x(0, 6, 4, 61),
x(0, 9, 5, 509), x(0, 15, 5, 32765),
x(0, 21, 5, 2097149), x(0, 3, 5, 5),
x(0, 7, 4, 125), x(0, 12, 5, 4093),
x(0, 18, 5, 262141), x(0, 23, 5, 8388605),
x(0, 5, 5, 29), x(0, 8, 4, 253),
x(0, 14, 5, 16381), x(0, 20, 5, 1048573),
x(0, 2, 5, 1), x(16, 7, 4, 125),
x(0, 11, 5, 2045), x(0, 17, 5, 131069),
x(0, 22, 5, 4194301), x(0, 4, 5, 13),
x(16, 8, 4, 253), x(0, 13, 5, 8189),
x(0, 19, 5, 524285), x(0, 1, 5, 1),
x(16, 6, 4, 61), x(0, 10, 5, 1021),
x(0, 16, 5, 65533), x(0, 28, 5, 268435453),
x(0, 27, 5, 134217725), x(0, 26, 5, 67108861),
x(0, 25, 5, 33554429), x(0, 24, 5, 16777213),
}
case tableMatchLengths:
want = []decSymbol{
/* nextState, nbAddBits, nbBits, baseVal */
{0, 0, 6, 3}, {0, 0, 4, 4},
{32, 0, 5, 5}, {0, 0, 5, 6},
{0, 0, 5, 8}, {0, 0, 5, 9},
{0, 0, 5, 11}, {0, 0, 6, 13},
{0, 0, 6, 16}, {0, 0, 6, 19},
{0, 0, 6, 22}, {0, 0, 6, 25},
{0, 0, 6, 28}, {0, 0, 6, 31},
{0, 0, 6, 34}, {0, 1, 6, 37},
{0, 1, 6, 41}, {0, 2, 6, 47},
{0, 3, 6, 59}, {0, 4, 6, 83},
{0, 7, 6, 131}, {0, 9, 6, 515},
{16, 0, 4, 4}, {0, 0, 4, 5},
{32, 0, 5, 6}, {0, 0, 5, 7},
{32, 0, 5, 9}, {0, 0, 5, 10},
{0, 0, 6, 12}, {0, 0, 6, 15},
{0, 0, 6, 18}, {0, 0, 6, 21},
{0, 0, 6, 24}, {0, 0, 6, 27},
{0, 0, 6, 30}, {0, 0, 6, 33},
{0, 1, 6, 35}, {0, 1, 6, 39},
{0, 2, 6, 43}, {0, 3, 6, 51},
{0, 4, 6, 67}, {0, 5, 6, 99},
{0, 8, 6, 259}, {32, 0, 4, 4},
{48, 0, 4, 4}, {16, 0, 4, 5},
{32, 0, 5, 7}, {32, 0, 5, 8},
{32, 0, 5, 10}, {32, 0, 5, 11},
{0, 0, 6, 14}, {0, 0, 6, 17},
{0, 0, 6, 20}, {0, 0, 6, 23},
{0, 0, 6, 26}, {0, 0, 6, 29},
{0, 0, 6, 32}, {0, 16, 6, 65539},
{0, 15, 6, 32771}, {0, 14, 6, 16387},
{0, 13, 6, 8195}, {0, 12, 6, 4099},
{0, 11, 6, 2051}, {0, 10, 6, 1027},
x(0, 0, 6, 3), x(0, 0, 4, 4),
x(32, 0, 5, 5), x(0, 0, 5, 6),
x(0, 0, 5, 8), x(0, 0, 5, 9),
x(0, 0, 5, 11), x(0, 0, 6, 13),
x(0, 0, 6, 16), x(0, 0, 6, 19),
x(0, 0, 6, 22), x(0, 0, 6, 25),
x(0, 0, 6, 28), x(0, 0, 6, 31),
x(0, 0, 6, 34), x(0, 1, 6, 37),
x(0, 1, 6, 41), x(0, 2, 6, 47),
x(0, 3, 6, 59), x(0, 4, 6, 83),
x(0, 7, 6, 131), x(0, 9, 6, 515),
x(16, 0, 4, 4), x(0, 0, 4, 5),
x(32, 0, 5, 6), x(0, 0, 5, 7),
x(32, 0, 5, 9), x(0, 0, 5, 10),
x(0, 0, 6, 12), x(0, 0, 6, 15),
x(0, 0, 6, 18), x(0, 0, 6, 21),
x(0, 0, 6, 24), x(0, 0, 6, 27),
x(0, 0, 6, 30), x(0, 0, 6, 33),
x(0, 1, 6, 35), x(0, 1, 6, 39),
x(0, 2, 6, 43), x(0, 3, 6, 51),
x(0, 4, 6, 67), x(0, 5, 6, 99),
x(0, 8, 6, 259), x(32, 0, 4, 4),
x(48, 0, 4, 4), x(16, 0, 4, 5),
x(32, 0, 5, 7), x(32, 0, 5, 8),
x(32, 0, 5, 10), x(32, 0, 5, 11),
x(0, 0, 6, 14), x(0, 0, 6, 17),
x(0, 0, 6, 20), x(0, 0, 6, 23),
x(0, 0, 6, 26), x(0, 0, 6, 29),
x(0, 0, 6, 32), x(0, 16, 6, 65539),
x(0, 15, 6, 32771), x(0, 14, 6, 16387),
x(0, 13, 6, 8195), x(0, 12, 6, 4099),
x(0, 11, 6, 2051), x(0, 10, 6, 1027),
}
}
pre := fsePredef[i]
got := pre.dt[:1<<pre.actualTableLog]
if !reflect.DeepEqual(got, want) {
t.Logf("want: %v", want)
t.Logf("got : %v", got)
t.Errorf("Predefined table %d incorrect, len(got) = %d, len(want) = %d", i, len(got), len(want))
}
}
Expand Down
113 changes: 80 additions & 33 deletions zstd/fse_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,29 +184,75 @@ func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error {
// decSymbol contains information about a state entry,
// including the state offset base, the output symbol and
// the number of bits to read for the low part of the destination state.
type decSymbol struct {
newState uint16
addBits uint8 // Used for symbols until transformed.
nbBits uint8
baseline uint32
// The entry is packed into one composite uint64 because that is faster than
// a struct with separate members.
//
// Bit layout, least significant bits first:
//
//	bits  0-7   nbBits
//	bits  8-15  addBits
//	bits 16-31  newState
//	bits 32-63  baseline
type decSymbol uint64

// newDecSymbol packs the four fields into a single decSymbol.
func newDecSymbol(nbits, addBits uint8, newState uint16, baseline uint32) decSymbol {
	packed := uint64(nbits)
	packed |= uint64(addBits) << 8
	packed |= uint64(newState) << 16
	packed |= uint64(baseline) << 32
	return decSymbol(packed)
}

// nbBits returns the value stored in bits 0-7.
func (d decSymbol) nbBits() uint8 {
	return uint8(d & 0xff)
}

// addBits returns the value stored in bits 8-15.
func (d decSymbol) addBits() uint8 {
	return uint8((d >> 8) & 0xff)
}

// newState returns the value stored in bits 16-31.
func (d decSymbol) newState() uint16 {
	return uint16((d >> 16) & 0xffff)
}

// baseline returns the value stored in bits 32-63.
func (d decSymbol) baseline() uint32 {
	return uint32(uint64(d) >> 32)
}

// baselineInt returns the baseline converted to int.
func (d decSymbol) baselineInt() int {
	return int(d.baseline())
}

// set overwrites every field of the entry.
func (d *decSymbol) set(nbits, addBits uint8, newState uint16, baseline uint32) {
	*d = newDecSymbol(nbits, addBits, newState, baseline)
}

// setNBits replaces nbBits and leaves the other fields untouched.
func (d *decSymbol) setNBits(nBits uint8) {
	rest := *d &^ 0xff
	*d = rest | decSymbol(nBits)
}

// setAddBits replaces addBits and leaves the other fields untouched.
func (d *decSymbol) setAddBits(addBits uint8) {
	rest := *d &^ (0xff << 8)
	*d = rest | decSymbol(addBits)<<8
}

// setNewState replaces newState and leaves the other fields untouched.
func (d *decSymbol) setNewState(state uint16) {
	rest := *d &^ (0xffff << 16)
	*d = rest | decSymbol(state)<<16
}

// setBaseline replaces baseline and leaves the other fields untouched.
func (d *decSymbol) setBaseline(baseline uint32) {
	rest := *d &^ (0xffffffff << 32)
	*d = rest | decSymbol(baseline)<<32
}

// setExt replaces addBits and baseline in one step, keeping nbBits and newState.
func (d *decSymbol) setExt(addBits uint8, baseline uint32) {
	rest := *d &^ (0xff<<8 | 0xffffffff<<32)
	*d = rest | decSymbol(addBits)<<8 | decSymbol(baseline)<<32
}

// decSymbolValue returns the transformed decSymbol for the given symbol.
//
// symb is used as an index into the lookup table t. The returned entry carries
// that table entry's addBits and baseLine; nbBits and newState are zero.
// An error is returned when symb is not a valid index into t.
func decSymbolValue(symb uint8, t []baseOffset) (decSymbol, error) {
	if int(symb) >= len(t) {
		// decSymbol is an integer type, so the zero value is the literal 0.
		return 0, fmt.Errorf("rle symbol %d >= max %d", symb, len(t))
	}
	lu := t[symb]
	return newDecSymbol(0, lu.addBits, 0, lu.baseLine), nil
}

// setRLE will set the decoder to RLE mode.
//
// The table log is reset to 0, maxBits takes the symbol's additional bit
// count, and the symbol becomes the only entry of the decoding table.
func (s *fseDecoder) setRLE(symbol decSymbol) {
	s.actualTableLog = 0
	// addBits is an accessor on the packed decSymbol and must be called.
	s.maxBits = symbol.addBits()
	s.dt[0] = symbol
}

Expand All @@ -220,7 +266,7 @@ func (s *fseDecoder) buildDtable() error {
{
for i, v := range s.norm[:s.symbolLen] {
if v == -1 {
s.dt[highThreshold].addBits = uint8(i)
s.dt[highThreshold].setAddBits(uint8(i))
highThreshold--
symbolNext[i] = 1
} else {
Expand All @@ -235,7 +281,7 @@ func (s *fseDecoder) buildDtable() error {
position := uint32(0)
for ss, v := range s.norm[:s.symbolLen] {
for i := 0; i < int(v); i++ {
s.dt[position].addBits = uint8(ss)
s.dt[position].setAddBits(uint8(ss))
position = (position + step) & tableMask
for position > highThreshold {
// lowprob area
Expand All @@ -253,11 +299,11 @@ func (s *fseDecoder) buildDtable() error {
{
tableSize := uint16(1 << s.actualTableLog)
for u, v := range s.dt[:tableSize] {
symbol := v.addBits
symbol := v.addBits()
nextState := symbolNext[symbol]
symbolNext[symbol] = nextState + 1
nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
s.dt[u&maxTableMask].nbBits = nBits
s.dt[u&maxTableMask].setNBits(nBits)
newState := (nextState << nBits) - tableSize
if newState > tableSize {
return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
Expand All @@ -266,7 +312,7 @@ func (s *fseDecoder) buildDtable() error {
// Seems weird that this is possible with nbits > 0.
return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
}
s.dt[u&maxTableMask].newState = newState
s.dt[u&maxTableMask].setNewState(newState)
}
}
return nil
Expand All @@ -279,25 +325,21 @@ func (s *fseDecoder) transform(t []baseOffset) error {
tableSize := uint16(1 << s.actualTableLog)
s.maxBits = 0
for i, v := range s.dt[:tableSize] {
if int(v.addBits) >= len(t) {
return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits, len(t))
add := v.addBits()
if int(add) >= len(t) {
return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits(), len(t))
}
lu := t[v.addBits]
lu := t[add]
if lu.addBits > s.maxBits {
s.maxBits = lu.addBits
}
s.dt[i&maxTableMask] = decSymbol{
newState: v.newState,
nbBits: v.nbBits,
addBits: lu.addBits,
baseline: lu.baseLine,
}
v.setExt(lu.addBits, lu.baseLine)
s.dt[i] = v
}
return nil
}

// fseState pairs a decoding table with the table entry currently selected
// while decoding.
type fseState struct {
// TODO: Check if *[1 << maxTablelog]decSymbol is faster.
// dt is the decoding table that state transitions index into.
dt []decSymbol
// state is the current table entry.
state decSymbol
}
Expand All @@ -312,26 +354,31 @@ func (s *fseState) init(br *bitReader, tableLog uint8, dt []decSymbol) {
// next advances the decoder to its next state.
// It reads nbBits bits for the current state from br and adds them to the
// state's newState base to select the next table entry.
// At least tablelog bits must be available in the bit reader.
func (s *fseState) next(br *bitReader) {
	lowBits := uint16(br.getBits(s.state.nbBits()))
	s.state = s.dt[s.state.newState()+lowBits]
}

// finished returns true if all bits have been read from the bitstream
// and the next state would require reading bits from the input.
func (s *fseState) finished(br *bitReader) bool {
	return br.finished() && s.state.nbBits() > 0
}

// final returns the current state symbol without decoding the next.
// The first result is the state's baseline as an int, the second its
// additional bit count.
func (s *fseState) final() (int, uint8) {
	return s.state.baselineInt(), s.state.addBits()
}

// final reports the baseline (as an int) and the additional bit count
// stored in this symbol, without decoding anything further.
func (d decSymbol) final() (int, uint8) {
	base := d.baselineInt()
	extra := d.addBits()
	return base, extra
}

// nextFast returns the next symbol and sets the next state.
// The results are the new state's baseline and additional bit count.
// This can only be used if no symbols are 0 bits.
// At least tablelog bits must be available in the bit reader.
func (s *fseState) nextFast(br *bitReader) (uint32, uint8) {
	lowBits := uint16(br.getBitsFast(s.state.nbBits()))
	s.state = s.dt[s.state.newState()+lowBits]
	return s.state.baseline(), s.state.addBits()
}
Loading