Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Asm decoder tweaks #537

Merged
merged 3 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/cpuinfo/cpuinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ func HasBMI2() bool {
return hasBMI2
}

// DisableBMI2 will disable BMI2, for testing purposes.
// Call returned function to restore previous state.
func DisableBMI2() func() {
old := hasBMI2
hasBMI2 = false
return func() {
hasBMI2 = old
}
}

// HasBMI checks whether an x86 CPU supports both BMI1 and BMI2 extensions.
func HasBMI() bool {
return HasBMI1() && HasBMI2()
Expand Down
41 changes: 25 additions & 16 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,16 @@ func (o options) genDecodeSeqAsm(name string) {
}

R14 := GP64()
MOVQ(ofState, R14) // copy ofState, its current value is needed below
if o.bmi2 {
tmp := GP64()
MOVQ(U32(8|(8<<8)), tmp)
BEXTRQ(tmp, ofState, R14)
} else {
MOVQ(ofState, R14) // copy ofState, its current value is needed below
SHRQ(U8(8), R14) // moB (from the ofState before its update)
MOVBQZX(R14.As8(), R14)
}

// Reload ctx
ctx := Dereference(Param("ctx"))
iteration, err := ctx.Field("iteration").Resolve()
Expand All @@ -188,8 +197,6 @@ func (o options) genDecodeSeqAsm(name string) {
Label(name + "_skip_update")

// mo = s.adjustOffset(mo, ll, moB)
SHRQ(U8(8), R14) // moB (from the ofState before its update)
MOVBQZX(R14.As8(), R14)

Comment("Adjust offset")

Expand Down Expand Up @@ -373,27 +380,27 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
})

DX := GP64()
MOVQ(state, DX) // TODO: maybe use BEXTR?
SHRQ(U8(16), DX)
MOVWQZX(DX.As16(), DX)

if !o.bmi2 {
// TODO: Probably reasonable to kip if AX==0s
CMPQ(AX, U8(0))
JZ(LabelRef(name + "_skip"))
if o.bmi2 {
tmp := GP64()
MOVQ(U32(16|(16<<8)), tmp)
BEXTRQ(tmp, state, DX)
} else {
MOVQ(state, DX)
SHRQ(U8(16), DX)
MOVWQZX(DX.As16(), DX)
}

{
lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead)
lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead, LabelRef(name+"_skip_zero"))
// Check if below tablelog
assert(func(ok LabelRef) {
CMPQ(lowBits, U32(512))
JB(ok)
})
ADDQ(lowBits, DX)
Label(name + "_skip_zero")
}

Label(name + "_skip")
// Load table pointer
tablePtr := GP64()
Comment("Load ctx." + table)
Expand All @@ -413,7 +420,9 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
}

func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
// getBits will return nbits bits from brValue.
// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned.
func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual {
BX := GP64()
CX := reg.CL
if o.bmi2 {
Expand All @@ -423,15 +432,15 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual)
ROLQ(CX, BX)
BZHIQ(nBits, BX, BX)
} else {
CMPQ(nBits, U8(0))
JZ(jmpZero)
MOVQ(brBitsRead, CX.As64())
ADDQ(nBits, brBitsRead)
MOVQ(brValue, BX)
SHLQ(CX, BX)
MOVQ(nBits, CX.As64())
NEGQ(CX.As64())
SHRQ(CX, BX)
TESTQ(nBits, nBits)
CMOVQEQ(nBits, BX)
}
return BX
}
Expand Down
21 changes: 8 additions & 13 deletions zstd/seqdec_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,13 @@ const errorMatchLenTooBig = 2
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
//
// Please refer to seqdec_generic.go for the reference implementation.
//go:noescape
func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//go:noescape
func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

type sequenceDecs_decode_function = func(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

var sequenceDecs_decode sequenceDecs_decode_function

func init() {
if cpuinfo.HasBMI2() {
sequenceDecs_decode = sequenceDecs_decode_bmi2
} else {
sequenceDecs_decode = sequenceDecs_decode_amd64
}
}

// decode sequences from the stream without the provided history.
func (s *sequenceDecs) decode(seqs []seqVals) error {
br := s.br
Expand All @@ -70,7 +60,12 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {

s.seqSize = 0

errCode := sequenceDecs_decode(s, br, &ctx)
var errCode int
if cpuinfo.HasBMI2() {
errCode = sequenceDecs_decode_bmi2(s, br, &ctx)
} else {
errCode = sequenceDecs_decode_amd64(s, br, &ctx)
}
if errCode != 0 {
i := len(seqs) - ctx.iteration
switch errCode {
Expand Down
59 changes: 25 additions & 34 deletions zstd/seqdec_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -142,30 +142,30 @@ sequenceDecs_decode_amd64_fill_3_byte_by_byte:
JMP sequenceDecs_decode_amd64_fill_3_byte_by_byte

sequenceDecs_decode_amd64_fill_3_end:
MOVQ R11, (SP)
MOVQ R9, AX
MOVQ ctx+16(FP), CX
CMPQ 96(CX), $0x00
JZ sequenceDecs_decode_amd64_skip_update
MOVQ R11, (SP)
MOVQ R9, AX
SHRQ $0x08, AX
MOVBQZX AL, AX
MOVQ ctx+16(FP), CX
CMPQ 96(CX), $0x00
JZ sequenceDecs_decode_amd64_skip_update

// Update Literal Length State
MOVBQZX DI, R11
SHRQ $0x10, DI
MOVWQZX DI, DI
CMPQ R11, $0x00
JZ sequenceDecs_decode_amd64_llState_updateState_skip
JZ sequenceDecs_decode_amd64_llState_updateState_skip_zero
MOVQ BX, CX
ADDQ R11, BX
MOVQ DX, R12
SHLQ CL, R12
MOVQ R11, CX
NEGQ CX
SHRQ CL, R12
TESTQ R11, R11
CMOVQEQ R11, R12
ADDQ R12, DI

sequenceDecs_decode_amd64_llState_updateState_skip:
sequenceDecs_decode_amd64_llState_updateState_skip_zero:
// Load ctx.llTable
MOVQ ctx+16(FP), CX
MOVQ (CX), CX
Expand All @@ -176,19 +176,17 @@ sequenceDecs_decode_amd64_llState_updateState_skip:
SHRQ $0x10, R8
MOVWQZX R8, R8
CMPQ R11, $0x00
JZ sequenceDecs_decode_amd64_mlState_updateState_skip
JZ sequenceDecs_decode_amd64_mlState_updateState_skip_zero
MOVQ BX, CX
ADDQ R11, BX
MOVQ DX, R12
SHLQ CL, R12
MOVQ R11, CX
NEGQ CX
SHRQ CL, R12
TESTQ R11, R11
CMOVQEQ R11, R12
ADDQ R12, R8

sequenceDecs_decode_amd64_mlState_updateState_skip:
sequenceDecs_decode_amd64_mlState_updateState_skip_zero:
// Load ctx.mlTable
MOVQ ctx+16(FP), CX
MOVQ 24(CX), CX
Expand All @@ -199,28 +197,23 @@ sequenceDecs_decode_amd64_mlState_updateState_skip:
SHRQ $0x10, R9
MOVWQZX R9, R9
CMPQ R11, $0x00
JZ sequenceDecs_decode_amd64_ofState_updateState_skip
JZ sequenceDecs_decode_amd64_ofState_updateState_skip_zero
MOVQ BX, CX
ADDQ R11, BX
MOVQ DX, R12
SHLQ CL, R12
MOVQ R11, CX
NEGQ CX
SHRQ CL, R12
TESTQ R11, R11
CMOVQEQ R11, R12
ADDQ R12, R9

sequenceDecs_decode_amd64_ofState_updateState_skip:
sequenceDecs_decode_amd64_ofState_updateState_skip_zero:
// Load ctx.ofTable
MOVQ ctx+16(FP), CX
MOVQ 48(CX), CX
MOVQ (CX)(R9*8), R9

sequenceDecs_decode_amd64_skip_update:
SHRQ $0x08, AX
MOVBQZX AL, AX

// Adjust offset
MOVQ s+0(FP), CX
MOVQ 16(R10), R11
Expand Down Expand Up @@ -444,16 +437,17 @@ sequenceDecs_decode_bmi2_fill_3_byte_by_byte:
JMP sequenceDecs_decode_bmi2_fill_3_byte_by_byte

sequenceDecs_decode_bmi2_fill_3_end:
MOVQ R10, (SP)
MOVQ R8, R10
MOVQ ctx+16(FP), CX
CMPQ 96(CX), $0x00
JZ sequenceDecs_decode_bmi2_skip_update
MOVQ R10, (SP)
MOVQ $0x00000808, CX
BEXTRQ CX, R8, R10
MOVQ ctx+16(FP), CX
CMPQ 96(CX), $0x00
JZ sequenceDecs_decode_bmi2_skip_update

// Update Literal Length State
MOVBQZX SI, R11
SHRQ $0x10, SI
MOVWQZX SI, SI
MOVQ $0x00001010, CX
BEXTRQ CX, SI, SI
LEAQ (DX)(R11*1), CX
MOVQ AX, R12
MOVQ CX, DX
Expand All @@ -468,8 +462,8 @@ sequenceDecs_decode_bmi2_fill_3_end:

// Update Match Length State
MOVBQZX DI, R11
SHRQ $0x10, DI
MOVWQZX DI, DI
MOVQ $0x00001010, CX
BEXTRQ CX, DI, DI
LEAQ (DX)(R11*1), CX
MOVQ AX, R12
MOVQ CX, DX
Expand All @@ -484,8 +478,8 @@ sequenceDecs_decode_bmi2_fill_3_end:

// Update Offset State
MOVBQZX R8, R11
SHRQ $0x10, R8
MOVWQZX R8, R8
MOVQ $0x00001010, CX
BEXTRQ CX, R8, R8
LEAQ (DX)(R11*1), CX
MOVQ AX, R12
MOVQ CX, DX
Expand All @@ -499,9 +493,6 @@ sequenceDecs_decode_bmi2_fill_3_end:
MOVQ (CX)(R8*8), R8

sequenceDecs_decode_bmi2_skip_update:
SHRQ $0x08, R10
MOVBQZX R10, R10

// Adjust offset
MOVQ s+0(FP), CX
MOVQ 16(R9), R11
Expand Down
Loading