diff --git a/internal/cpuinfo/cpuinfo.go b/internal/cpuinfo/cpuinfo.go index f85345cc0c..3954c51219 100644 --- a/internal/cpuinfo/cpuinfo.go +++ b/internal/cpuinfo/cpuinfo.go @@ -15,6 +15,16 @@ func HasBMI2() bool { return hasBMI2 } +// DisableBMI2 will disable BMI2, for testing purposes. +// Call returned function to restore previous state. +func DisableBMI2() func() { + old := hasBMI2 + hasBMI2 = false + return func() { + hasBMI2 = old + } +} + // HasBMI checks whether an x86 CPU supports both BMI1 and BMI2 extensions. func HasBMI() bool { return HasBMI1() && HasBMI2() diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go index 084513ee6c..15d5bce525 100644 --- a/zstd/_generate/gen.go +++ b/zstd/_generate/gen.go @@ -165,7 +165,16 @@ func (o options) genDecodeSeqAsm(name string) { } R14 := GP64() - MOVQ(ofState, R14) // copy ofState, its current value is needed below + if o.bmi2 { + tmp := GP64() + MOVQ(U32(8|(8<<8)), tmp) + BEXTRQ(tmp, ofState, R14) + } else { + MOVQ(ofState, R14) // copy ofState, its current value is needed below + SHRQ(U8(8), R14) // moB (from the ofState before its update) + MOVBQZX(R14.As8(), R14) + } + // Reload ctx ctx := Dereference(Param("ctx")) iteration, err := ctx.Field("iteration").Resolve() @@ -188,8 +197,6 @@ func (o options) genDecodeSeqAsm(name string) { Label(name + "_skip_update") // mo = s.adjustOffset(mo, ll, moB) - SHRQ(U8(8), R14) // moB (from the ofState before its update) - MOVBQZX(R14.As8(), R14) Comment("Adjust offset") @@ -373,27 +380,27 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu }) DX := GP64() - MOVQ(state, DX) // TODO: maybe use BEXTR? - SHRQ(U8(16), DX) - MOVWQZX(DX.As16(), DX) - - if !o.bmi2 { - // TODO: Probably reasonable to kip if AX==0s - CMPQ(AX, U8(0)) - JZ(LabelRef(name + "_skip")) + if o.bmi2 { + tmp := GP64() + MOVQ(U32(16|(16<<8)), tmp) + BEXTRQ(tmp, state, DX) + } else { + MOVQ(state, DX) + SHRQ(U8(16), DX) + MOVWQZX(DX.As16(), DX) } { - lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead) + lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead, LabelRef(name+"_skip_zero")) // Check if below tablelog assert(func(ok LabelRef) { CMPQ(lowBits, U32(512)) JB(ok) }) ADDQ(lowBits, DX) + Label(name + "_skip_zero") } - Label(name + "_skip") // Load table pointer tablePtr := GP64() Comment("Load ctx." + table) @@ -413,7 +420,9 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state) } -func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual { +// getBits will return nbits bits from brValue. +// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned. +func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual { BX := GP64() CX := reg.CL if o.bmi2 { @@ -423,6 +432,8 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) ROLQ(CX, BX) BZHIQ(nBits, BX, BX) } else { + CMPQ(nBits, U8(0)) + JZ(jmpZero) MOVQ(brBitsRead, CX.As64()) ADDQ(nBits, brBitsRead) MOVQ(brValue, BX) @@ -430,8 +441,6 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) MOVQ(nBits, CX.As64()) NEGQ(CX.As64()) SHRQ(CX, BX) - TESTQ(nBits, nBits) - CMOVQEQ(nBits, BX) } return BX } diff --git a/zstd/seqdec_amd64.go b/zstd/seqdec_amd64.go index b6832ee257..de9787b4d5 100644 --- a/zstd/seqdec_amd64.go +++ b/zstd/seqdec_amd64.go @@ -30,23 +30,13 @@ const errorMatchLenTooBig = 2 // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. // // Please refer to seqdec_generic.go for the reference implementation. +//go:noescape func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. +//go:noescape func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int -type sequenceDecs_decode_function = func(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int - -var sequenceDecs_decode sequenceDecs_decode_function - -func init() { - if cpuinfo.HasBMI2() { - sequenceDecs_decode = sequenceDecs_decode_bmi2 - } else { - sequenceDecs_decode = sequenceDecs_decode_amd64 - } -} - // decode sequences from the stream without the provided history. func (s *sequenceDecs) decode(seqs []seqVals) error { br := s.br @@ -70,7 +60,12 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { s.seqSize = 0 - errCode := sequenceDecs_decode(s, br, &ctx) + var errCode int + if cpuinfo.HasBMI2() { + errCode = sequenceDecs_decode_bmi2(s, br, &ctx) + } else { + errCode = sequenceDecs_decode_amd64(s, br, &ctx) + } if errCode != 0 { i := len(seqs) - ctx.iteration switch errCode { diff --git a/zstd/seqdec_amd64.s b/zstd/seqdec_amd64.s index d3d6fc9863..1a4fbad918 100644 --- a/zstd/seqdec_amd64.s +++ b/zstd/seqdec_amd64.s @@ -142,18 +142,20 @@ sequenceDecs_decode_amd64_fill_3_byte_by_byte: JMP sequenceDecs_decode_amd64_fill_3_byte_by_byte sequenceDecs_decode_amd64_fill_3_end: - MOVQ R11, (SP) - MOVQ R9, AX - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decode_amd64_skip_update + MOVQ R11, (SP) + MOVQ R9, AX + SHRQ $0x08, AX + MOVBQZX AL, AX + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_amd64_skip_update // Update Literal Length State MOVBQZX DI, R11 SHRQ $0x10, DI MOVWQZX DI, DI CMPQ R11, $0x00 - JZ sequenceDecs_decode_amd64_llState_updateState_skip + JZ sequenceDecs_decode_amd64_llState_updateState_skip_zero MOVQ BX, CX ADDQ R11, BX MOVQ DX, R12 @@ -161,11 +163,9 @@ sequenceDecs_decode_amd64_fill_3_end: MOVQ R11, CX NEGQ CX SHRQ CL, R12 - TESTQ R11, R11 - CMOVQEQ R11, R12 ADDQ R12, DI -sequenceDecs_decode_amd64_llState_updateState_skip: +sequenceDecs_decode_amd64_llState_updateState_skip_zero: // Load ctx.llTable MOVQ ctx+16(FP), CX MOVQ (CX), CX @@ -176,7 +176,7 @@ sequenceDecs_decode_amd64_llState_updateState_skip: SHRQ $0x10, R8 MOVWQZX R8, R8 CMPQ R11, $0x00 - JZ sequenceDecs_decode_amd64_mlState_updateState_skip + JZ sequenceDecs_decode_amd64_mlState_updateState_skip_zero MOVQ BX, CX ADDQ R11, BX MOVQ DX, R12 @@ -184,11 +184,9 @@ sequenceDecs_decode_amd64_llState_updateState_skip: MOVQ R11, CX NEGQ CX SHRQ CL, R12 - TESTQ R11, R11 - CMOVQEQ R11, R12 ADDQ R12, R8 -sequenceDecs_decode_amd64_mlState_updateState_skip: +sequenceDecs_decode_amd64_mlState_updateState_skip_zero: // Load ctx.mlTable MOVQ ctx+16(FP), CX MOVQ 24(CX), CX @@ -199,7 +197,7 @@ sequenceDecs_decode_amd64_mlState_updateState_skip: SHRQ $0x10, R9 MOVWQZX R9, R9 CMPQ R11, $0x00 - JZ sequenceDecs_decode_amd64_ofState_updateState_skip + JZ sequenceDecs_decode_amd64_ofState_updateState_skip_zero MOVQ BX, CX ADDQ R11, BX MOVQ DX, R12 @@ -207,20 +205,15 @@ sequenceDecs_decode_amd64_mlState_updateState_skip: MOVQ R11, CX NEGQ CX SHRQ CL, R12 - TESTQ R11, R11 - CMOVQEQ R11, R12 ADDQ R12, R9 -sequenceDecs_decode_amd64_ofState_updateState_skip: +sequenceDecs_decode_amd64_ofState_updateState_skip_zero: // Load ctx.ofTable MOVQ ctx+16(FP), CX MOVQ 48(CX), CX MOVQ (CX)(R9*8), R9 sequenceDecs_decode_amd64_skip_update: - SHRQ $0x08, AX - MOVBQZX AL, AX - // Adjust offset MOVQ s+0(FP), CX MOVQ 16(R10), R11 @@ -444,16 +437,17 @@ sequenceDecs_decode_bmi2_fill_3_byte_by_byte: JMP sequenceDecs_decode_bmi2_fill_3_byte_by_byte sequenceDecs_decode_bmi2_fill_3_end: - MOVQ R10, (SP) - MOVQ R8, R10 - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decode_bmi2_skip_update + MOVQ R10, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R10 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_bmi2_skip_update // Update Literal Length State MOVBQZX SI, R11 - SHRQ $0x10, SI - MOVWQZX SI, SI + MOVQ $0x00001010, CX + BEXTRQ CX, SI, SI LEAQ (DX)(R11*1), CX MOVQ AX, R12 MOVQ CX, DX @@ -468,8 +462,8 @@ sequenceDecs_decode_bmi2_fill_3_end: // Update Match Length State MOVBQZX DI, R11 - SHRQ $0x10, DI - MOVWQZX DI, DI + MOVQ $0x00001010, CX + BEXTRQ CX, DI, DI LEAQ (DX)(R11*1), CX MOVQ AX, R12 MOVQ CX, DX @@ -484,8 +478,8 @@ sequenceDecs_decode_bmi2_fill_3_end: // Update Offset State MOVBQZX R8, R11 - SHRQ $0x10, R8 - MOVWQZX R8, R8 + MOVQ $0x00001010, CX + BEXTRQ CX, R8, R8 LEAQ (DX)(R11*1), CX MOVQ AX, R12 MOVQ CX, DX @@ -499,9 +493,6 @@ sequenceDecs_decode_bmi2_fill_3_end: MOVQ (CX)(R8*8), R8 sequenceDecs_decode_bmi2_skip_update: - SHRQ $0x08, R10 - MOVBQZX R10, R10 - // Adjust offset MOVQ s+0(FP), CX MOVQ 16(R9), R11 diff --git a/zstd/seqdec_amd64_test.go b/zstd/seqdec_amd64_test.go new file mode 100644 index 0000000000..06983e58f4 --- /dev/null +++ b/zstd/seqdec_amd64_test.go @@ -0,0 +1,155 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +package zstd + +import ( + "bytes" + "encoding/csv" + "fmt" + "io/ioutil" + "os" + "reflect" + "strconv" + "testing" + + "github.com/klauspost/compress/internal/cpuinfo" + "github.com/klauspost/compress/zip" +) + +func Benchmark_seqdec_decodeNoBMI(b *testing.B) { + if !cpuinfo.HasBMI2() { + b.Skip("Already tested, platform does not have bmi2") + return + } + defer cpuinfo.DisableBMI2()() + + benchmark_seqdec_decode(b) +} + +func Test_sequenceDecs_decodeNoBMI(t *testing.T) { + if !cpuinfo.HasBMI2() { + t.Skip("Already tested, platform does not have bmi2") + return + } + defer cpuinfo.DisableBMI2()() + + const writeWant = false + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + want := map[string][]seqVals{} + var wantOffsets = map[string][3]int{} + if !writeWant { + fn := "testdata/seqs-want.zip" + data, err := ioutil.ReadFile(fn) + tb := t + if err != nil { + tb.Fatal(err) + } + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + tb.Fatal(err) + } + for _, tt := range zr.File { + var ref testSequence + if !ref.parse(tt.Name) { + tb.Skip("unable to parse:", tt.Name) + } + o, err := tt.Open() + if err != nil { + t.Fatal(err) + } + r := csv.NewReader(o) + recs, err := r.ReadAll() + if err != nil { + t.Fatal(err) + } + for i, rec := range recs { + if i == 0 { + var o [3]int + o[0], _ = strconv.Atoi(rec[0]) + o[1], _ = strconv.Atoi(rec[1]) + o[2], _ = strconv.Atoi(rec[2]) + wantOffsets[tt.Name] = o + continue + } + s := seqVals{} + s.mo, _ = strconv.Atoi(rec[0]) + s.ml, _ = strconv.Atoi(rec[1]) + s.ll, _ = strconv.Atoi(rec[2]) + want[tt.Name] = append(want[tt.Name], s) + } + o.Close() + } + } + fn := "testdata/seqs.zip" + data, err := ioutil.ReadFile(fn) + tb := t + if err != nil { + tb.Fatal(err) + } + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + tb.Fatal(err) + } + for _, tt := range zr.File { + var ref testSequence + if !ref.parse(tt.Name) { + tb.Skip("unable to parse:", tt.Name) + } + r, err := tt.Open() + if err != nil { + tb.Error(err) + return + } + + seqData, err := ioutil.ReadAll(r) + if err != nil { + tb.Error(err) + return + } + var buf = bytes.NewBuffer(seqData) + s := readDecoders(tb, buf, ref) + seqs := make([]seqVals, ref.n) + + t.Run(tt.Name, func(t *testing.T) { + fatalIf := func(err error) { + if err != nil { + t.Fatal(err) + } + } + fatalIf(s.br.init(buf.Bytes())) + fatalIf(s.litLengths.init(s.br)) + fatalIf(s.offsets.init(s.br)) + fatalIf(s.matchLengths.init(s.br)) + + err := s.decode(seqs) + if err != nil { + t.Error(err) + } + if writeWant { + w, err := zw.Create(tt.Name) + fatalIf(err) + c := csv.NewWriter(w) + w.Write([]byte(fmt.Sprintf("%d,%d,%d\n", s.prevOffset[0], s.prevOffset[1], s.prevOffset[2]))) + for _, seq := range seqs { + c.Write([]string{strconv.Itoa(seq.mo), strconv.Itoa(seq.ml), strconv.Itoa(seq.ll)}) + } + c.Flush() + } else { + if s.prevOffset != wantOffsets[tt.Name] { + t.Errorf("want offsets %v, got %v", wantOffsets[tt.Name], s.prevOffset) + } + + if !reflect.DeepEqual(want[tt.Name], seqs) { + t.Errorf("got %v\nwant %v", seqs, want[tt.Name]) + } + } + }) + } + if writeWant { + zw.Close() + ioutil.WriteFile("testdata/seqs-want.zip", buf.Bytes(), os.ModePerm) + } +} diff --git a/zstd/seqdec_test.go b/zstd/seqdec_test.go index 0737b69e09..d20115b323 100644 --- a/zstd/seqdec_test.go +++ b/zstd/seqdec_test.go @@ -404,6 +404,10 @@ func Test_seqdec_decodeSync(t *testing.T) { } func Benchmark_seqdec_decode(b *testing.B) { + benchmark_seqdec_decode(b) +} + +func benchmark_seqdec_decode(b *testing.B) { fn := "testdata/seqs.zip" data, err := ioutil.ReadFile(fn) tb := b