Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: faster next state update in BMI2 version of decode #593

Merged
merged 1 commit into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 102 additions & 6 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,48 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute

// Update states, max tablelog 28
{
Comment("Update Literal Length State")
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
Comment("Update Match Length State")
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
Comment("Update Offset State")
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
if o.bmi2 {
// Get total number of bits (it is safe, as nBits is <= 9, thus 3*9 < 255)
total := GP64()
LEAQ(Mem{Base: llState, Index: mlState, Scale: 1}, total)
ADDQ(ofState, total)
MOVBQZX(total.As8(), total) // total = llState.As8() + mlState.As8() + ofState.As8()

// Read `total` bits
bits := o.getBitsValue(name+"_getBits", total, brValue, brBitsRead)

// Update states
Comment("Update Offset State")
{
nBits := ofState // Note: SHRXQ uses lower 6 bits of shift amount and BZHIQ lower 8 bits of count
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
SHRXQ(nBits, bits, bits) // bits >>= nBits
o.nextState(name+"_ofState", ofState, lowBits, "ofTable")
}
Comment("Update Match Length State")
{
nBits := mlState
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
SHRXQ(nBits, bits, bits) // lowBits >>= nBits
o.nextState(name+"_mlState", mlState, lowBits, "mlTable")
}
Comment("Update Literal Length State")
{
nBits := llState
lowBits := GP64()
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
o.nextState(name+"_llState", llState, lowBits, "llTable")
}
} else {
Comment("Update Literal Length State")
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
Comment("Update Match Length State")
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
Comment("Update Offset State")
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
}
}
Label(name + "_skip_update")

Expand Down Expand Up @@ -624,6 +660,39 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
}

func (o options) nextState(name string, state, lowBits reg.GPVirtual, table string) {
DX := GP64()
if o.bmi2 {
tmp := GP64()
MOVQ(U32(16|(16<<8)), tmp)
BEXTRQ(tmp, state, DX)
} else {
MOVQ(state, DX)
SHRQ(U8(16), DX)
MOVWQZX(DX.As16(), DX)
}

ADDQ(lowBits, DX)

// Load table pointer
tablePtr := GP64()
Comment("Load ctx." + table)
ctx := Dereference(Param("ctx"))
tableA, err := ctx.Field(table).Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(tableA.Addr, tablePtr)

// Check if below tablelog
assert(func(ok LabelRef) {
CMPQ(DX, U32(512))
JB(ok)
})
// Load new state
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
}

// getBits will return nbits bits from brValue.
// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned.
func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual {
Expand All @@ -649,6 +718,33 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual,
return BX
}

// getBits will return nbits bits from brValue.
// If nbits == 0 then 0 is returned.
func (o options) getBitsValue(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
BX := GP64()
CX := reg.CL
if o.bmi2 {
LEAQ(Mem{Base: brBitsRead, Index: nBits, Scale: 1}, CX.As64())
MOVQ(brValue, BX)
MOVQ(CX.As64(), brBitsRead)
ROLQ(CX, BX)
BZHIQ(nBits, BX, BX)
} else {
XORQ(BX, BX)
CMPQ(nBits, U8(0))
JZ(LabelRef(name + "_get_bits_value_zero"))
MOVQ(brBitsRead, CX.As64())
ADDQ(nBits, brBitsRead)
MOVQ(brValue, BX)
SHLQ(CX, BX)
MOVQ(nBits, CX.As64())
NEGQ(CX.As64())
SHRQ(CX, BX)
Label(name + "_get_bits_value_zero")
}
return BX
}

func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) {
offset = GP64()
MOVQ(moP, offset)
Expand Down
11 changes: 6 additions & 5 deletions zstd/seqdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
}
}
}

// Add final literals
copy(out[t:], s.literals)
if debugDecoder {
Expand All @@ -203,12 +204,11 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {

// decode sequences from the stream with the provided history.
func (s *sequenceDecs) decodeSync(hist []byte) error {
if true {
supported, err := s.decodeSyncSimple(hist)
if supported {
return err
}
supported, err := s.decodeSyncSimple(hist)
if supported {
return err
}

br := s.br
seqs := s.nSeqs
startSize := len(s.out)
Expand Down Expand Up @@ -396,6 +396,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
ofState = ofTable[ofState.newState()&maxTableMask]
} else {
bits := br.get32BitsFast(nBits)

lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
llState = llTable[(llState.newState()+lowBits)&maxTableMask]

Expand Down
Loading