Skip to content

Commit

Permalink
zstd: faster next state update in BMI2 version of decode (#593)
Browse files Browse the repository at this point in the history
Use the Go-code approach: use single getBits to obtain three bitfields.
  • Loading branch information
WojciechMula authored May 12, 2022
1 parent 6ebbb85 commit 348514c
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 159 deletions.
108 changes: 102 additions & 6 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,48 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute

// Update states, max tablelog 28
{
Comment("Update Literal Length State")
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
Comment("Update Match Length State")
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
Comment("Update Offset State")
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
if o.bmi2 {
// Get total number of bits (it is safe, as nBits is <= 9, thus 3*9 < 255)
total := GP64()
LEAQ(Mem{Base: llState, Index: mlState, Scale: 1}, total)
ADDQ(ofState, total)
MOVBQZX(total.As8(), total) // total = llState.As8() + mlState.As8() + ofState.As8()

// Read `total` bits
bits := o.getBitsValue(name+"_getBits", total, brValue, brBitsRead)

// Update states
Comment("Update Offset State")
{
nBits := ofState // Note: SHRXQ uses lower 6 bits of shift amount and BZHIQ lower 8 bits of count
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
SHRXQ(nBits, bits, bits) // bits >>= nBits
o.nextState(name+"_ofState", ofState, lowBits, "ofTable")
}
Comment("Update Match Length State")
{
nBits := mlState
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
SHRXQ(nBits, bits, bits) // lowBits >>= nBits
o.nextState(name+"_mlState", mlState, lowBits, "mlTable")
}
Comment("Update Literal Length State")
{
nBits := llState
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
o.nextState(name+"_llState", llState, lowBits, "llTable")
}
} else {
Comment("Update Literal Length State")
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
Comment("Update Match Length State")
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
Comment("Update Offset State")
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
}
}
Label(name + "_skip_update")

Expand Down Expand Up @@ -624,6 +660,39 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
}

func (o options) nextState(name string, state, lowBits reg.GPVirtual, table string) {
DX := GP64()
if o.bmi2 {
tmp := GP64()
MOVQ(U32(16|(16<<8)), tmp)
BEXTRQ(tmp, state, DX)
} else {
MOVQ(state, DX)
SHRQ(U8(16), DX)
MOVWQZX(DX.As16(), DX)
}

ADDQ(lowBits, DX)

// Load table pointer
tablePtr := GP64()
Comment("Load ctx." + table)
ctx := Dereference(Param("ctx"))
tableA, err := ctx.Field(table).Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(tableA.Addr, tablePtr)

// Check if below tablelog
assert(func(ok LabelRef) {
CMPQ(DX, U32(512))
JB(ok)
})
// Load new state
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
}

// getBits will return nbits bits from brValue.
// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned.
func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual {
Expand All @@ -649,6 +718,33 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual,
return BX
}

// getBits will return nbits bits from brValue.
// If nbits == 0 then 0 is returned.
func (o options) getBitsValue(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
BX := GP64()
CX := reg.CL
if o.bmi2 {
LEAQ(Mem{Base: brBitsRead, Index: nBits, Scale: 1}, CX.As64())
MOVQ(brValue, BX)
MOVQ(CX.As64(), brBitsRead)
ROLQ(CX, BX)
BZHIQ(nBits, BX, BX)
} else {
XORQ(BX, BX)
CMPQ(nBits, U8(0))
JZ(LabelRef(name + "_get_bits_value_zero"))
MOVQ(brBitsRead, CX.As64())
ADDQ(nBits, brBitsRead)
MOVQ(brValue, BX)
SHLQ(CX, BX)
MOVQ(nBits, CX.As64())
NEGQ(CX.As64())
SHRQ(CX, BX)
Label(name + "_get_bits_value_zero")
}
return BX
}

func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) {
offset = GP64()
MOVQ(moP, offset)
Expand Down
11 changes: 6 additions & 5 deletions zstd/seqdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
}
}
}

// Add final literals
copy(out[t:], s.literals)
if debugDecoder {
Expand All @@ -203,12 +204,11 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {

// decode sequences from the stream with the provided history.
func (s *sequenceDecs) decodeSync(hist []byte) error {
if true {
supported, err := s.decodeSyncSimple(hist)
if supported {
return err
}
supported, err := s.decodeSyncSimple(hist)
if supported {
return err
}

br := s.br
seqs := s.nSeqs
startSize := len(s.out)
Expand Down Expand Up @@ -396,6 +396,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
ofState = ofTable[ofState.newState()&maxTableMask]
} else {
bits := br.get32BitsFast(nBits)

lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
llState = llTable[(llState.newState()+lowBits)&maxTableMask]

Expand Down
Loading

0 comments on commit 348514c

Please sign in to comment.