From 6c809acb72f5b1c6e01d59a6ce45bbdacb0f31e1 Mon Sep 17 00:00:00 2001
From: Klaus Post
Date: Fri, 18 Mar 2022 07:06:19 -0700
Subject: [PATCH] zstd: Asm decoder tweaks (#537)

* Add non-bmi amd64 tests
* Use BEXTRQ for extracting shifted values.
* Move 0 check into getBits.
* Remove ctx alloc.

Sequences only, BMI:

```
benchmark  old ns/op  new ns/op  delta
Benchmark_seqdec_decode/n-12286-lits-13914-prev-9869-1990358-3296656-win-4194304.blk-32  91657  91114  -0.59%
Benchmark_seqdec_decode/n-12485-lits-6960-prev-976039-2250252-2463561-win-4194304.blk-32  92392  90416  -2.14%
Benchmark_seqdec_decode/n-14746-lits-14461-prev-209-8-1379909-win-4194304.blk-32  83022  79745  -3.95%
Benchmark_seqdec_decode/n-1525-lits-1498-prev-2009476-797934-2994405-win-4194304.blk-32  9149  8856  -3.20%
Benchmark_seqdec_decode/n-3478-lits-3628-prev-895243-2104056-2119329-win-4194304.blk-32  22402  22102  -1.34%
Benchmark_seqdec_decode/n-8422-lits-5840-prev-168095-2298675-433830-win-4194304.blk-32  60844  60114  -1.20%
Benchmark_seqdec_decode/n-1000-lits-1057-prev-21887-92-217-win-8388608.blk-32  5785  5879  +1.62%
Benchmark_seqdec_decode/n-15134-lits-20798-prev-4882976-4884216-4474622-win-8388608.blk-32  118030  115597  -2.06%
Benchmark_seqdec_decode/n-2-lits-0-prev-620601-689171-848-win-8388608.blk-32  135  64.3  -52.35%
Benchmark_seqdec_decode/n-90-lits-67-prev-19498-23-19710-win-8388608.blk-32  648  589  -9.03%
Benchmark_seqdec_decode/n-931-lits-1179-prev-36502-1526-1518-win-8388608.blk-32  5555  5467  -1.58%
Benchmark_seqdec_decode/n-2898-lits-4062-prev-335-386-751-win-8388608.blk-32  17896  17605  -1.63%
Benchmark_seqdec_decode/n-4056-lits-12419-prev-10792-66-309849-win-8388608.blk-32  27457  27232  -0.82%
Benchmark_seqdec_decode/n-8028-lits-4568-prev-917-65-920-win-8388608.blk-32  59341  58158  -1.99%
```

No BMI:

```
benchmark  old ns/op  new ns/op  delta
Benchmark_seqdec_decodeNoBMI/n-12286-lits-13914-prev-9869-1990358-3296656-win-4194304.blk-32  114889  113333  -1.35%
Benchmark_seqdec_decodeNoBMI/n-12485-lits-6960-prev-976039-2250252-2463561-win-4194304.blk-32  121269  119500  -1.46%
Benchmark_seqdec_decodeNoBMI/n-14746-lits-14461-prev-209-8-1379909-win-4194304.blk-32  106986  102585  -4.11%
Benchmark_seqdec_decodeNoBMI/n-1525-lits-1498-prev-2009476-797934-2994405-win-4194304.blk-32  10910  10304  -5.55%
Benchmark_seqdec_decodeNoBMI/n-3478-lits-3628-prev-895243-2104056-2119329-win-4194304.blk-32  25965  24642  -5.10%
Benchmark_seqdec_decodeNoBMI/n-8422-lits-5840-prev-168095-2298675-433830-win-4194304.blk-32  80183  77980  -2.75%
Benchmark_seqdec_decodeNoBMI/n-1000-lits-1057-prev-21887-92-217-win-8388608.blk-32  6702  6369  -4.97%
Benchmark_seqdec_decodeNoBMI/n-15134-lits-20798-prev-4882976-4884216-4474622-win-8388608.blk-32  151867  148752  -2.05%
Benchmark_seqdec_decodeNoBMI/n-2-lits-0-prev-620601-689171-848-win-8388608.blk-32  139  46.8  -66.31%
Benchmark_seqdec_decodeNoBMI/n-90-lits-67-prev-19498-23-19710-win-8388608.blk-32  744  609  -18.13%
Benchmark_seqdec_decodeNoBMI/n-931-lits-1179-prev-36502-1526-1518-win-8388608.blk-32  6570  6083  -7.41%
Benchmark_seqdec_decodeNoBMI/n-2898-lits-4062-prev-335-386-751-win-8388608.blk-32  20448  19955  -2.41%
Benchmark_seqdec_decodeNoBMI/n-4056-lits-12419-prev-10792-66-309849-win-8388608.blk-32  34177  32790  -4.06%
Benchmark_seqdec_decodeNoBMI/n-8028-lits-4568-prev-917-65-920-win-8388608.blk-32  77864  75628  -2.87%
```
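For readers who do not have the BMI instructions memorized: BEXTRQ extracts a bit field in a single instruction, taking a control operand whose low byte is the start bit and whose second byte is the field length. That is what replaces the SHRQ+MOVBQZX and SHRQ+MOVWQZX pairs in the generator and in the generated assembly below. A rough Go model of the semantics and of the two control values used in this patch (the helper name is illustrative only, not code from this repository):

```go
// bextr models BEXTR: ctrl bits 0-7 select the start bit, bits 8-15 the field length.
func bextr(ctrl uint16, src uint64) uint64 {
	start := uint(ctrl & 0xff)
	length := uint(ctrl >> 8)
	if start >= 64 {
		return 0
	}
	field := src >> start
	if length >= 64 {
		return field
	}
	return field & (uint64(1)<<length - 1)
}

// Control values used below:
//   8 | (8 << 8)   == 0x0808: bits 8..15, the byte the generator calls moB.
//   16 | (16 << 8) == 0x1010: bits 16..31, the 16-bit value used for the next-state lookup.
```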
---
 internal/cpuinfo/cpuinfo.go |  10 +++
 zstd/_generate/gen.go       |  41 ++++++----
 zstd/seqdec_amd64.go        |  21 ++---
 zstd/seqdec_amd64.s         |  59 ++++++--------
 zstd/seqdec_amd64_test.go   | 155 ++++++++++++++++++++++++++++++++++++
 zstd/seqdec_test.go         |   4 +
 6 files changed, 227 insertions(+), 63 deletions(-)
 create mode 100644 zstd/seqdec_amd64_test.go

diff --git a/internal/cpuinfo/cpuinfo.go b/internal/cpuinfo/cpuinfo.go
index f85345cc0c..3954c51219 100644
--- a/internal/cpuinfo/cpuinfo.go
+++ b/internal/cpuinfo/cpuinfo.go
@@ -15,6 +15,16 @@ func HasBMI2() bool {
 	return hasBMI2
 }
 
+// DisableBMI2 will disable BMI2, for testing purposes.
+// Call returned function to restore previous state.
+func DisableBMI2() func() {
+	old := hasBMI2
+	hasBMI2 = false
+	return func() {
+		hasBMI2 = old
+	}
+}
+
 // HasBMI checks whether an x86 CPU supports both BMI1 and BMI2 extensions.
 func HasBMI() bool {
 	return HasBMI1() && HasBMI2()
diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go
index 084513ee6c..15d5bce525 100644
--- a/zstd/_generate/gen.go
+++ b/zstd/_generate/gen.go
@@ -165,7 +165,16 @@ func (o options) genDecodeSeqAsm(name string) {
 	}
 
 	R14 := GP64()
-	MOVQ(ofState, R14) // copy ofState, its current value is needed below
+	if o.bmi2 {
+		tmp := GP64()
+		MOVQ(U32(8|(8<<8)), tmp)
+		BEXTRQ(tmp, ofState, R14)
+	} else {
+		MOVQ(ofState, R14) // copy ofState, its current value is needed below
+		SHRQ(U8(8), R14)   // moB (from the ofState before its update)
+		MOVBQZX(R14.As8(), R14)
+	}
+
 	// Reload ctx
 	ctx := Dereference(Param("ctx"))
 	iteration, err := ctx.Field("iteration").Resolve()
@@ -188,8 +197,6 @@ func (o options) genDecodeSeqAsm(name string) {
 	Label(name + "_skip_update")
 
 	// mo = s.adjustOffset(mo, ll, moB)
-	SHRQ(U8(8), R14) // moB (from the ofState before its update)
-	MOVBQZX(R14.As8(), R14)
 
 	Comment("Adjust offset")
 
@@ -373,27 +380,27 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
 	})
 
 	DX := GP64()
-	MOVQ(state, DX) // TODO: maybe use BEXTR?
-	SHRQ(U8(16), DX)
-	MOVWQZX(DX.As16(), DX)
-
-	if !o.bmi2 {
-		// TODO: Probably reasonable to kip if AX==0s
-		CMPQ(AX, U8(0))
-		JZ(LabelRef(name + "_skip"))
+	if o.bmi2 {
+		tmp := GP64()
+		MOVQ(U32(16|(16<<8)), tmp)
+		BEXTRQ(tmp, state, DX)
+	} else {
+		MOVQ(state, DX)
+		SHRQ(U8(16), DX)
+		MOVWQZX(DX.As16(), DX)
 	}
 
 	{
-		lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead)
+		lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead, LabelRef(name+"_skip_zero"))
 		// Check if below tablelog
 		assert(func(ok LabelRef) {
 			CMPQ(lowBits, U32(512))
 			JB(ok)
 		})
 		ADDQ(lowBits, DX)
+		Label(name + "_skip_zero")
 	}
-	Label(name + "_skip")
 
 	// Load table pointer
 	tablePtr := GP64()
 	Comment("Load ctx." + table)
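The zero-bit special case moves with the lookup: updateState no longer guards the call, and the non-BMI path of getBits (next hunk) branches to the caller-supplied jmpZero label instead of patching the result with TESTQ/CMOVQEQ afterwards. The special case exists because x86 shift counts on 64-bit operands are taken mod 64, so shifting right by 64-0 shifts by zero and returns the un-masked value instead of 0. A plain-Go sketch of what the non-BMI path computes (illustrative names, not code from the package):

```go
// getBitsModel returns the next n bits from value, of which bitsRead bits
// have already been consumed from the top.
func getBitsModel(value uint64, bitsRead *uint8, n uint8) uint64 {
	if n == 0 {
		// The assembly needs this branch (or the old CMOV): SHRQ by (64-0)&63 == 0
		// would leave the whole shifted value instead of producing 0.
		return 0
	}
	v := (value << *bitsRead) >> (64 - n)
	*bitsRead += n
	return v
}
```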
@@ -413,7 +420,9 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
 	MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
 }
 
-func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
+// getBits will return nbits bits from brValue.
+// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned.
+func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual {
 	BX := GP64()
 	CX := reg.CL
 	if o.bmi2 {
@@ -423,6 +432,8 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual)
 		ROLQ(CX, BX)
 		BZHIQ(nBits, BX, BX)
 	} else {
+		CMPQ(nBits, U8(0))
+		JZ(jmpZero)
 		MOVQ(brBitsRead, CX.As64())
 		ADDQ(nBits, brBitsRead)
 		MOVQ(brValue, BX)
@@ -430,8 +441,6 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual)
 		MOVQ(nBits, CX.As64())
 		NEGQ(CX.As64())
 		SHRQ(CX, BX)
-		TESTQ(nBits, nBits)
-		CMOVQEQ(nBits, BX)
 	}
 	return BX
 }
diff --git a/zstd/seqdec_amd64.go b/zstd/seqdec_amd64.go
index b6832ee257..de9787b4d5 100644
--- a/zstd/seqdec_amd64.go
+++ b/zstd/seqdec_amd64.go
@@ -30,23 +30,13 @@ const errorMatchLenTooBig = 2
 // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
 //
 // Please refer to seqdec_generic.go for the reference implementation.
+//go:noescape
 func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
 
 // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
+//go:noescape
 func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
 
-type sequenceDecs_decode_function = func(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
-
-var sequenceDecs_decode sequenceDecs_decode_function
-
-func init() {
-	if cpuinfo.HasBMI2() {
-		sequenceDecs_decode = sequenceDecs_decode_bmi2
-	} else {
-		sequenceDecs_decode = sequenceDecs_decode_amd64
-	}
-}
-
 // decode sequences from the stream without the provided history.
 func (s *sequenceDecs) decode(seqs []seqVals) error {
 	br := s.br
@@ -70,7 +60,12 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
 
 	s.seqSize = 0
 
-	errCode := sequenceDecs_decode(s, br, &ctx)
+	var errCode int
+	if cpuinfo.HasBMI2() {
+		errCode = sequenceDecs_decode_bmi2(s, br, &ctx)
+	} else {
+		errCode = sequenceDecs_decode_amd64(s, br, &ctx)
+	}
 	if errCode != 0 {
 		i := len(seqs) - ctx.iteration
 		switch errCode {
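This is presumably what the "Remove ctx alloc." bullet refers to: with the package-level function variable gone, both assembly stubs are statically known at the call site and carry //go:noescape, so the decodeAsmContext can stay on the caller's stack instead of escaping to the heap for an indirect call. That reading is consistent with the outsized gain on the tiny n-2 benchmarks above, but it is an interpretation, not something stated in the patch. A minimal sketch of the dispatch pattern with hypothetical names:

```go
// decodeCtx, decodeBMI2 and decodeGeneric are stand-ins for the real
// decodeAsmContext and the //go:noescape assembly stubs.
type decodeCtx struct{ iteration int }

func decodeBMI2(ctx *decodeCtx) int    { return 0 }
func decodeGeneric(ctx *decodeCtx) int { return 0 }

func decode(hasBMI2 bool) int {
	// Before: errCode := decodeFn(&ctx) through a package-level variable,
	// which the compiler cannot see through, so &ctx escapes.
	// After: branch at the call site; both callees are known.
	var ctx decodeCtx
	if hasBMI2 {
		return decodeBMI2(&ctx)
	}
	return decodeGeneric(&ctx)
}
```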
diff --git a/zstd/seqdec_amd64.s b/zstd/seqdec_amd64.s
index d3d6fc9863..1a4fbad918 100644
--- a/zstd/seqdec_amd64.s
+++ b/zstd/seqdec_amd64.s
@@ -142,18 +142,20 @@ sequenceDecs_decode_amd64_fill_3_byte_by_byte:
 	JMP sequenceDecs_decode_amd64_fill_3_byte_by_byte
 
 sequenceDecs_decode_amd64_fill_3_end:
-	MOVQ R11, (SP)
-	MOVQ R9, AX
-	MOVQ ctx+16(FP), CX
-	CMPQ 96(CX), $0x00
-	JZ   sequenceDecs_decode_amd64_skip_update
+	MOVQ    R11, (SP)
+	MOVQ    R9, AX
+	SHRQ    $0x08, AX
+	MOVBQZX AL, AX
+	MOVQ    ctx+16(FP), CX
+	CMPQ    96(CX), $0x00
+	JZ      sequenceDecs_decode_amd64_skip_update
 
 	// Update Literal Length State
 	MOVBQZX DI, R11
 	SHRQ    $0x10, DI
 	MOVWQZX DI, DI
 	CMPQ    R11, $0x00
-	JZ      sequenceDecs_decode_amd64_llState_updateState_skip
+	JZ      sequenceDecs_decode_amd64_llState_updateState_skip_zero
 	MOVQ    BX, CX
 	ADDQ    R11, BX
 	MOVQ    DX, R12
@@ -161,11 +163,9 @@ sequenceDecs_decode_amd64_fill_3_end:
 	MOVQ    R11, CX
 	NEGQ    CX
 	SHRQ    CL, R12
-	TESTQ   R11, R11
-	CMOVQEQ R11, R12
 	ADDQ    R12, DI
 
-sequenceDecs_decode_amd64_llState_updateState_skip:
+sequenceDecs_decode_amd64_llState_updateState_skip_zero:
 	// Load ctx.llTable
 	MOVQ ctx+16(FP), CX
 	MOVQ (CX), CX
@@ -176,7 +176,7 @@ sequenceDecs_decode_amd64_llState_updateState_skip:
 	SHRQ    $0x10, R8
 	MOVWQZX R8, R8
 	CMPQ    R11, $0x00
-	JZ      sequenceDecs_decode_amd64_mlState_updateState_skip
+	JZ      sequenceDecs_decode_amd64_mlState_updateState_skip_zero
 	MOVQ    BX, CX
 	ADDQ    R11, BX
 	MOVQ    DX, R12
@@ -184,11 +184,9 @@ sequenceDecs_decode_amd64_llState_updateState_skip:
 	MOVQ    R11, CX
 	NEGQ    CX
 	SHRQ    CL, R12
-	TESTQ   R11, R11
-	CMOVQEQ R11, R12
 	ADDQ    R12, R8
 
-sequenceDecs_decode_amd64_mlState_updateState_skip:
+sequenceDecs_decode_amd64_mlState_updateState_skip_zero:
 	// Load ctx.mlTable
 	MOVQ ctx+16(FP), CX
 	MOVQ 24(CX), CX
@@ -199,7 +197,7 @@ sequenceDecs_decode_amd64_mlState_updateState_skip:
 	SHRQ    $0x10, R9
 	MOVWQZX R9, R9
 	CMPQ    R11, $0x00
-	JZ      sequenceDecs_decode_amd64_ofState_updateState_skip
+	JZ      sequenceDecs_decode_amd64_ofState_updateState_skip_zero
 	MOVQ    BX, CX
 	ADDQ    R11, BX
 	MOVQ    DX, R12
@@ -207,20 +205,15 @@ sequenceDecs_decode_amd64_mlState_updateState_skip:
 	MOVQ    R11, CX
 	NEGQ    CX
 	SHRQ    CL, R12
-	TESTQ   R11, R11
-	CMOVQEQ R11, R12
 	ADDQ    R12, R9
 
-sequenceDecs_decode_amd64_ofState_updateState_skip:
+sequenceDecs_decode_amd64_ofState_updateState_skip_zero:
 	// Load ctx.ofTable
 	MOVQ ctx+16(FP), CX
 	MOVQ 48(CX), CX
 	MOVQ (CX)(R9*8), R9
 
 sequenceDecs_decode_amd64_skip_update:
-	SHRQ    $0x08, AX
-	MOVBQZX AL, AX
-
 	// Adjust offset
 	MOVQ s+0(FP), CX
 	MOVQ 16(R10), R11
@@ -444,16 +437,17 @@ sequenceDecs_decode_bmi2_fill_3_byte_by_byte:
 	JMP sequenceDecs_decode_bmi2_fill_3_byte_by_byte
 
 sequenceDecs_decode_bmi2_fill_3_end:
-	MOVQ R10, (SP)
-	MOVQ R8, R10
-	MOVQ ctx+16(FP), CX
-	CMPQ 96(CX), $0x00
-	JZ   sequenceDecs_decode_bmi2_skip_update
+	MOVQ   R10, (SP)
+	MOVQ   $0x00000808, CX
+	BEXTRQ CX, R8, R10
+	MOVQ   ctx+16(FP), CX
+	CMPQ   96(CX), $0x00
+	JZ     sequenceDecs_decode_bmi2_skip_update
 
 	// Update Literal Length State
 	MOVBQZX SI, R11
-	SHRQ    $0x10, SI
-	MOVWQZX SI, SI
+	MOVQ    $0x00001010, CX
+	BEXTRQ  CX, SI, SI
 	LEAQ    (DX)(R11*1), CX
 	MOVQ    AX, R12
 	MOVQ    CX, DX
@@ -468,8 +462,8 @@ sequenceDecs_decode_bmi2_fill_3_end:
 
 	// Update Match Length State
 	MOVBQZX DI, R11
-	SHRQ    $0x10, DI
-	MOVWQZX DI, DI
+	MOVQ    $0x00001010, CX
+	BEXTRQ  CX, DI, DI
 	LEAQ    (DX)(R11*1), CX
 	MOVQ    AX, R12
 	MOVQ    CX, DX
@@ -484,8 +478,8 @@ sequenceDecs_decode_bmi2_fill_3_end:
 
 	// Update Offset State
 	MOVBQZX R8, R11
-	SHRQ    $0x10, R8
-	MOVWQZX R8, R8
+	MOVQ    $0x00001010, CX
+	BEXTRQ  CX, R8, R8
 	LEAQ    (DX)(R11*1), CX
 	MOVQ    AX, R12
 	MOVQ    CX, DX
@@ -499,9 +493,6 @@ sequenceDecs_decode_bmi2_fill_3_end:
 	MOVQ (CX)(R8*8), R8
 
 sequenceDecs_decode_bmi2_skip_update:
-	SHRQ    $0x08, R10
-	MOVBQZX R10, R10
-
 	// Adjust offset
 	MOVQ s+0(FP), CX
 	MOVQ 16(R9), R11
diff --git a/zstd/seqdec_amd64_test.go b/zstd/seqdec_amd64_test.go
new file mode 100644
index 0000000000..06983e58f4
--- /dev/null
+++ b/zstd/seqdec_amd64_test.go
@@ -0,0 +1,155 @@
+//go:build amd64 && !appengine && !noasm && gc
+// +build amd64,!appengine,!noasm,gc
+
+package zstd
+
+import (
+	"bytes"
+	"encoding/csv"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"reflect"
+	"strconv"
+	"testing"
+
+	"github.com/klauspost/compress/internal/cpuinfo"
+	"github.com/klauspost/compress/zip"
+)
+
+func Benchmark_seqdec_decodeNoBMI(b *testing.B) {
+	if !cpuinfo.HasBMI2() {
+		b.Skip("Already tested, platform does not have bmi2")
+		return
+	}
+	defer cpuinfo.DisableBMI2()()
+
+	benchmark_seqdec_decode(b)
+}
+
+func Test_sequenceDecs_decodeNoBMI(t *testing.T) {
+	if !cpuinfo.HasBMI2() {
+		t.Skip("Already tested, platform does not have bmi2")
+		return
+	}
+	defer cpuinfo.DisableBMI2()()
+
+	const writeWant = false
+	var buf bytes.Buffer
+	zw := zip.NewWriter(&buf)
+
+	want := map[string][]seqVals{}
+	var wantOffsets = map[string][3]int{}
+	if !writeWant {
+		fn := "testdata/seqs-want.zip"
+		data, err := ioutil.ReadFile(fn)
+		tb := t
+		if err != nil {
+			tb.Fatal(err)
+		}
+		zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
+		if err != nil {
+			tb.Fatal(err)
+		}
+		for _, tt := range zr.File {
+			var ref testSequence
+			if !ref.parse(tt.Name) {
+				tb.Skip("unable to parse:", tt.Name)
+			}
+			o, err := tt.Open()
+			if err != nil {
+				t.Fatal(err)
+			}
+			r := csv.NewReader(o)
+			recs, err := r.ReadAll()
+			if err != nil {
+				t.Fatal(err)
+			}
+			for i, rec := range recs {
+				if i == 0 {
+					var o [3]int
+					o[0], _ = strconv.Atoi(rec[0])
+					o[1], _ = strconv.Atoi(rec[1])
+					o[2], _ = strconv.Atoi(rec[2])
+					wantOffsets[tt.Name] = o
+					continue
+				}
+				s := seqVals{}
+				s.mo, _ = strconv.Atoi(rec[0])
+				s.ml, _ = strconv.Atoi(rec[1])
+				s.ll, _ = strconv.Atoi(rec[2])
+				want[tt.Name] = append(want[tt.Name], s)
+			}
+			o.Close()
+		}
+	}
+	fn := "testdata/seqs.zip"
+	data, err := ioutil.ReadFile(fn)
+	tb := t
+	if err != nil {
+		tb.Fatal(err)
+	}
+	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
+	if err != nil {
+		tb.Fatal(err)
+	}
+	for _, tt := range zr.File {
+		var ref testSequence
+		if !ref.parse(tt.Name) {
+			tb.Skip("unable to parse:", tt.Name)
+		}
+		r, err := tt.Open()
+		if err != nil {
+			tb.Error(err)
+			return
+		}
+
+		seqData, err := ioutil.ReadAll(r)
+		if err != nil {
+			tb.Error(err)
+			return
+		}
+		var buf = bytes.NewBuffer(seqData)
+		s := readDecoders(tb, buf, ref)
+		seqs := make([]seqVals, ref.n)
+
+		t.Run(tt.Name, func(t *testing.T) {
+			fatalIf := func(err error) {
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+			fatalIf(s.br.init(buf.Bytes()))
+			fatalIf(s.litLengths.init(s.br))
+			fatalIf(s.offsets.init(s.br))
+			fatalIf(s.matchLengths.init(s.br))
+
+			err := s.decode(seqs)
+			if err != nil {
+				t.Error(err)
+			}
+			if writeWant {
+				w, err := zw.Create(tt.Name)
+				fatalIf(err)
+				c := csv.NewWriter(w)
+				w.Write([]byte(fmt.Sprintf("%d,%d,%d\n", s.prevOffset[0], s.prevOffset[1], s.prevOffset[2])))
+				for _, seq := range seqs {
+					c.Write([]string{strconv.Itoa(seq.mo), strconv.Itoa(seq.ml), strconv.Itoa(seq.ll)})
+				}
+				c.Flush()
+			} else {
+				if s.prevOffset != wantOffsets[tt.Name] {
+					t.Errorf("want offsets %v, got %v", wantOffsets[tt.Name], s.prevOffset)
+				}
+
+				if !reflect.DeepEqual(want[tt.Name], seqs) {
+					t.Errorf("got %v\nwant %v", seqs, want[tt.Name])
+				}
+			}
+		})
+	}
+	if writeWant {
+		zw.Close()
+		ioutil.WriteFile("testdata/seqs-want.zip", buf.Bytes(), os.ModePerm)
+	}
+}
diff --git a/zstd/seqdec_test.go b/zstd/seqdec_test.go
index 0737b69e09..d20115b323 100644
--- a/zstd/seqdec_test.go
+++ b/zstd/seqdec_test.go
@@ -404,6 +404,10 @@ func Test_seqdec_decodeSync(t *testing.T) {
 }
 
 func Benchmark_seqdec_decode(b *testing.B) {
+	benchmark_seqdec_decode(b)
+}
+
+func benchmark_seqdec_decode(b *testing.B) {
 	fn := "testdata/seqs.zip"
 	data, err := ioutil.ReadFile(fn)
 	tb := b
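A note on the new tests: `defer cpuinfo.DisableBMI2()()` is the call-then-defer idiom. DisableBMI2() flips the flag immediately and returns a restore function, and it is that returned function which defer runs when the test exits. A generic sketch of the same idiom (names are illustrative, not from the package):

```go
// disableFlag switches a feature flag off and returns a function that
// restores the previous value, mirroring cpuinfo.DisableBMI2.
func disableFlag(flag *bool) (restore func()) {
	old := *flag
	*flag = false
	return func() { *flag = old }
}

func runWithoutFeature() {
	feature := true
	defer disableFlag(&feature)() // off for the rest of this function, restored on return
	// ... exercise the non-accelerated code path ...
}
```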