Skip to content

Commit

Permalink
Remove writer reference on close (#224)
Browse files Browse the repository at this point in the history
Furthermore:

* Clean up dictionary handling & test.
* Fixes dictionary use in levels 1-6.

Fixes #219
  • Loading branch information
klauspost authored Feb 16, 2020
1 parent b949da4 commit 59173b5
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 64 deletions.
35 changes: 13 additions & 22 deletions flate/deflate.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ type compressionLevel struct {
// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/
var levels = []compressionLevel{
{}, // 0
// Level 1-4 uses specialized algorithm - values not used
// Level 1-6 uses specialized algorithm - values not used
{0, 0, 0, 0, 0, 1},
{0, 0, 0, 0, 0, 2},
{0, 0, 0, 0, 0, 3},
{0, 0, 0, 0, 0, 4},
// For levels 5-6 we don't bother trying with lazy matches.
// Lazy matching is at least 30% slower, with 1.5% increase.
{6, 0, 12, 8, 12, 5},
{8, 0, 24, 16, 16, 6},
{0, 0, 0, 0, 0, 5},
{0, 0, 0, 0, 0, 6},
// Levels 7-9 use increasingly more lazy matching
// and increasingly stringent conditions for "good enough".
{8, 8, 24, 16, skipNever, 7},
Expand Down Expand Up @@ -203,9 +201,8 @@ func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error {
// This is much faster than doing a full encode.
// Should only be used after a start/reset.
func (d *compressor) fillWindow(b []byte) {
// Do not fill window if we are in store-only mode,
// use constant or Snappy compression.
if d.level == 0 {
// Do not fill window if we are in store-only or huffman mode.
if d.level <= 0 {
return
}
if d.fast != nil {
Expand Down Expand Up @@ -667,6 +664,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
default:
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
}
d.level = level
return nil
}

Expand Down Expand Up @@ -720,6 +718,7 @@ func (d *compressor) close() error {
return d.w.err
}
d.w.flush()
d.w.reset(nil)
return d.w.err
}

Expand Down Expand Up @@ -750,8 +749,7 @@ func NewWriter(w io.Writer, level int) (*Writer, error) {
// can only be decompressed by a Reader initialized with the
// same dictionary.
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
dw := &dictWriter{w}
zw, err := NewWriter(dw, level)
zw, err := NewWriter(w, level)
if err != nil {
return nil, err
}
Expand All @@ -760,14 +758,6 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
return zw, err
}

type dictWriter struct {
w io.Writer
}

func (w *dictWriter) Write(b []byte) (n int, err error) {
return w.w.Write(b)
}

// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
Expand Down Expand Up @@ -805,11 +795,12 @@ func (w *Writer) Close() error {
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.writer.(*dictWriter); ok {
if len(w.dict) > 0 {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
w.d.fillWindow(w.dict)
w.d.reset(dst)
if dst != nil {
w.d.fillWindow(w.dict)
}
} else {
// w was created with NewWriter
w.d.reset(dst)
Expand Down
95 changes: 53 additions & 42 deletions flate/deflate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,54 +516,65 @@ func TestWriterReset(t *testing.T) {
t.Errorf("level %d Writer not reset after Reset", level)
}
}
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, ConstantCompression) })
dict := []byte("we are the world")
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, ConstantCompression, dict) })
}

func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.Bytes()

buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("level-", i), func(w io.Writer) (*Writer, error) { return NewWriter(w, i) })
}
w.Close()
out2 := buf2.Bytes()

if len(out1) != len(out2) {
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
dict := []byte(strings.Repeat("we are the world - how are you?", 3))
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("dict-level-", i), func(w io.Writer) (*Writer, error) { return NewWriterDict(w, i, dict) })
}
if bytes.Compare(out1, out2) != 0 {
mm := 0
for i, b := range out1[:len(out2)] {
if b != out2[i] {
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("dict-reset-level-", i), func(w io.Writer) (*Writer, error) {
w2, err := NewWriter(nil, i)
if err != nil {
return w2, err
}
mm++
if mm == 10 {
t.Fatal("Stopping")
w2.ResetDict(w, dict)
return w2, nil
})
}
}

func testResetOutput(t *testing.T, name string, newWriter func(w io.Writer) (*Writer, error)) {
t.Run(name, func(t *testing.T) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world - how are you doing?")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.Bytes()

buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out2 := buf2.Bytes()

if len(out1) != len(out2) {
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
}
if bytes.Compare(out1, out2) != 0 {
mm := 0
for i, b := range out1[:len(out2)] {
if b != out2[i] {
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
}
mm++
if mm == 10 {
t.Fatal("Stopping")
}
}
}
}
t.Logf("got %d bytes", len(out1))
t.Logf("got %d bytes", len(out1))
})
}

// TestBestSpeed tests that round-tripping through deflate and then inflate
Expand Down

0 comments on commit 59173b5

Please sign in to comment.