diff --git a/huff0/bitwriter.go b/huff0/bitwriter.go index bda4021efd..6bce4e87d4 100644 --- a/huff0/bitwriter.go +++ b/huff0/bitwriter.go @@ -43,6 +43,11 @@ func (b *bitWriter) addBits16Clean(value uint16, bits uint8) { func (b *bitWriter) encSymbol(ct cTable, symbol byte) { enc := ct[symbol] b.bitContainer |= uint64(enc.val) << (b.nBits & 63) + if false { + if enc.nBits == 0 { + panic("nbits 0") + } + } b.nBits += enc.nBits } @@ -54,6 +59,14 @@ func (b *bitWriter) encTwoSymbols(ct cTable, av, bv byte) { sh := b.nBits & 63 combined := uint64(encA.val) | (uint64(encB.val) << (encA.nBits & 63)) b.bitContainer |= combined << sh + if false { + if encA.nBits == 0 { + panic("nbitsA 0") + } + if encB.nBits == 0 { + panic("nbitsB 0") + } + } b.nBits += encA.nBits + encB.nBits } diff --git a/huff0/compress.go b/huff0/compress.go index 0843cb014f..f9ed5f8306 100644 --- a/huff0/compress.go +++ b/huff0/compress.go @@ -77,8 +77,11 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) // Each symbol present maximum once or too well distributed. return nil, false, ErrIncompressible } - - if s.Reuse == ReusePolicyPrefer && canReuse { + if s.Reuse == ReusePolicyMust && !canReuse { + // We must reuse, but we can't. + return nil, false, ErrIncompressible + } + if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse { keepTable := s.cTable keepTL := s.actualTableLog s.cTable = s.prevTable @@ -90,6 +93,9 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) s.OutData = s.Out return s.Out, true, nil } + if s.Reuse == ReusePolicyMust { + return nil, false, ErrIncompressible + } // Do not attempt to re-use later. s.prevTable = s.prevTable[:0] } diff --git a/huff0/compress_test.go b/huff0/compress_test.go index ac264a5a18..f389e912aa 100644 --- a/huff0/compress_test.go +++ b/huff0/compress_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "math/rand" "path/filepath" + "reflect" "strings" "testing" @@ -265,6 +266,95 @@ func TestCompress1X(t *testing.T) { } } +func TestCompress1XMustReuse(t *testing.T) { + for _, test := range testfiles { + t.Run(test.name, func(t *testing.T) { + var s Scratch + buf0, err := test.fn() + if err != nil { + t.Fatal(err) + } + if len(buf0) > BlockSizeMax { + buf0 = buf0[:BlockSizeMax] + } + b, re, err := Compress1X(buf0, &s) + if err != test.err1X { + t.Errorf("want error %v (%T), got %v (%T)", test.err1X, test.err1X, err, err) + } + if err != nil { + t.Log(test.name, err.Error()) + return + } + if b == nil { + t.Error("got no output") + return + } + + min := s.minSize(len(buf0)) + if len(s.OutData) < min { + t.Errorf("output data length (%d) below shannon limit (%d)", len(s.OutData), min) + } + if len(s.OutTable) == 0 { + t.Error("got no table definition") + } + if re { + t.Error("claimed to have re-used.") + } + if len(s.OutData) == 0 { + t.Error("got no data output") + } + t.Logf("%s: %d -> %d bytes (%.2f:1) re:%t (table: %d bytes)", test.name, len(buf0), len(b), float64(len(buf0))/float64(len(b)), re, len(s.OutTable)) + table := s.OutTable + prevTable := s.prevTable + for i, v := range prevTable { + // Clear unused sections for comparison + if v.nBits == 0 { + prevTable[i].val = 0 + } + } + b = s.OutData + actl := s.actualTableLog + + // Use only the table data to recompress. + s = Scratch{} + s2 := &s + s.Reuse = ReusePolicyMust + s2, _, err = ReadTable(table, s2) + if err != nil { + t.Error("Could not read table", err) + return + } + if !reflect.DeepEqual(prevTable, s2.prevTable) { + t.Errorf("prevtable mismatch.\ngot %v\nwant %v", s2.prevTable, prevTable) + } + if actl != s.actualTableLog { + t.Errorf("tablelog mismatch, want %d, got %d", actl, s.actualTableLog) + } + b2, reused, err := Compress1X(buf0, s2) + if err != nil { + t.Error("Could not re-compress with prev table", err) + } + if !reused { + t.Error("didn't reuse...") + return + } + if len(b2) != len(b) { + t.Errorf("recompressed to different size, want %d, got %d", len(b), len(b2)) + return + } + + if !bytes.Equal(b, b2) { + for i := range b { + if b[i] != b2[i] { + t.Errorf("recompressed to different output. First mismatch at byte %d, (want %x != got %x)", i, b[i], b2[i]) + return + } + } + } + }) + } +} + func TestCompress4X(t *testing.T) { for _, test := range testfiles { t.Run(test.name, func(t *testing.T) { diff --git a/huff0/decompress.go b/huff0/decompress.go index a03b2634af..41703bba4d 100644 --- a/huff0/decompress.go +++ b/huff0/decompress.go @@ -32,7 +32,7 @@ const use8BitTables = true // The size of the input may be larger than the table definition. // Any content remaining after the table definition will be returned. // If no Scratch is provided a new one is allocated. -// The returned Scratch can be used for decoding input using this table. +// The returned Scratch can be used for encoding or decoding input using this table. func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { s, err = s.prepare(in) if err != nil { @@ -58,8 +58,8 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { s.symbolLen = uint16(oSize) in = in[iSize:] } else { - if len(in) <= int(iSize) { - return s, nil, errors.New("input too small for table") + if len(in) < int(iSize) { + return s, nil, fmt.Errorf("input too small for table, want %d bytes, have %d", iSize, len(in)) } // FSE compressed weights s.fse.DecompressLimit = 255 @@ -138,15 +138,33 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { if len(s.dt.single) != tSize { s.dt.single = make([]dEntrySingle, tSize) } + cTable := s.prevTable + if cap(cTable) < maxSymbolValue+1 { + cTable = make([]cTableEntry, 0, maxSymbolValue+1) + } + cTable = cTable[:maxSymbolValue+1] + s.prevTable = cTable[:s.symbolLen] + s.prevTableLog = s.actualTableLog + for n, w := range s.huffWeight[:s.symbolLen] { if w == 0 { + cTable[n] = cTableEntry{ + val: 0, + nBits: 0, + } continue } length := (uint32(1) << w) >> 1 d := dEntrySingle{ entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8), } + rank := &rankStats[w] + cTable[n] = cTableEntry{ + val: uint16(*rank >> (w - 1)), + nBits: uint8(d.entry), + } + single := s.dt.single[*rank : *rank+length] for i := range single { single[i] = d diff --git a/huff0/huff0.go b/huff0/huff0.go index 177d6c4ea0..5dd66854b0 100644 --- a/huff0/huff0.go +++ b/huff0/huff0.go @@ -55,6 +55,9 @@ const ( // ReusePolicyNone will disable re-use of tables. // This is slightly faster than ReusePolicyAllow but may produce larger output. ReusePolicyNone + + // ReusePolicyMust must allow reuse and produce smaller output. + ReusePolicyMust ) type Scratch struct {