diff --git a/s2/encode.go b/s2/encode.go
index 53431d99ee..cdb3ab45b2 100644
--- a/s2/encode.go
+++ b/s2/encode.go
@@ -238,6 +238,7 @@ func NewWriter(w io.Writer, opts ...WriterOption) *Writer {
 	w2 := Writer{
 		blockSize:   defaultBlockSize,
 		concurrency: runtime.GOMAXPROCS(0),
+		randSrc:     rand.Reader,
 	}
 	for _, opt := range opts {
 		if err := opt(&w2); err != nil {
@@ -272,12 +273,14 @@ type Writer struct {
 	pad         int
 
 	writer   io.Writer
+	randSrc  io.Reader
 	writerWg sync.WaitGroup
 
 	// wroteStreamHeader is whether we have written the stream header.
 	wroteStreamHeader bool
 	paramsOK          bool
 	better            bool
+	uncompressed      bool
 }
 
 type result []byte
@@ -482,7 +485,7 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 			var n2 int
 			if w.better {
 				n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
-			} else {
+			} else if !w.uncompressed {
 				n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
 			}
 
@@ -559,7 +562,7 @@ func (w *Writer) write(p []byte) (nRet int, errRet error) {
 			var n2 int
 			if w.better {
 				n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
-			} else {
+			} else if !w.uncompressed {
 				n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
 			}
 
@@ -635,7 +638,7 @@ func (w *Writer) writeFull(inbuf []byte) (errRet error) {
 		var n2 int
 		if w.better {
 			n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
-		} else {
+		} else if !w.uncompressed {
 			n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
 		}
 
@@ -704,7 +707,7 @@ func (w *Writer) writeSync(p []byte) (nRet int, errRet error) {
 		var n2 int
 		if w.better {
 			n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
-		} else {
+		} else if !w.uncompressed {
 			n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
 		}
 
@@ -793,7 +796,7 @@ func (w *Writer) Close() error {
 	}
 	if w.err(nil) == nil && w.writer != nil && w.pad > 0 {
 		add := calcSkippableFrame(w.written, int64(w.pad))
-		frame, err := skippableFrame(w.ibuf[:0], add, rand.Reader)
+		frame, err := skippableFrame(w.ibuf[:0], add, w.randSrc)
 		if err = w.err(err); err != nil {
 			return err
 		}
@@ -877,11 +880,23 @@ func WriterConcurrency(n int) WriterOption {
 // 10-40% speed decrease on both compression and decompression.
 func WriterBetterCompression() WriterOption {
 	return func(w *Writer) error {
+		w.uncompressed = false
 		w.better = true
 		return nil
 	}
 }
 
+// WriterUncompressed will bypass compression.
+// The stream will be written as uncompressed blocks only.
+// If concurrency is > 1 CRC and output will still be done async.
+func WriterUncompressed() WriterOption {
+	return func(w *Writer) error {
+		w.better = false
+		w.uncompressed = true
+		return nil
+	}
+}
+
 // WriterBlockSize allows to override the default block size.
 // Blocks will be this size or smaller.
 // Minimum size is 4KB and and maximum size is 4MB.
@@ -922,3 +937,12 @@ func WriterPadding(n int) WriterOption {
 		return nil
 	}
 }
+
+// WriterPaddingSrc will get random data for padding from the supplied source.
+// By default crypto/rand is used.
+func WriterPaddingSrc(reader io.Reader) WriterOption {
+	return func(w *Writer) error {
+		w.randSrc = reader
+		return nil
+	}
+}
diff --git a/s2/encode_test.go b/s2/encode_test.go
index 34b9dc7775..2d83095d0d 100644
--- a/s2/encode_test.go
+++ b/s2/encode_test.go
@@ -16,6 +16,51 @@ import (
 	"github.com/klauspost/compress/zip"
 )
 
+// testOptions returns the named writer option sets the encoder tests iterate over.
+// Each base mode is combined with single-threaded, block size and padding variants;
+// the block size and larger padding variants are left out with -short.
+func testOptions(t *testing.T) map[string][]WriterOption {
+	var testOptions = map[string][]WriterOption{
+		"default": {},
+		"better":  {WriterBetterCompression()},
+		"none":    {WriterUncompressed()},
+	}
+
+	x := make(map[string][]WriterOption)
+	cloneAdd := func(org []WriterOption, add ...WriterOption) []WriterOption {
+		y := make([]WriterOption, len(org)+len(add))
+		copy(y, org)
+		copy(y[len(org):], add)
+		return y
+	}
+	for name, opt := range testOptions {
+		x[name] = opt
+		x[name+"-c1"] = cloneAdd(opt, WriterConcurrency(1))
+	}
+	testOptions = x
+	x = make(map[string][]WriterOption)
+	for name, opt := range testOptions {
+		x[name] = opt
+		if !testing.Short() {
+			// WriterBlockSize documents 4KB as its minimum, so the small-window case uses exactly that.
+			x[name+"-4k-win"] = cloneAdd(opt, WriterBlockSize(4<<10))
+			x[name+"-4M-win"] = cloneAdd(opt, WriterBlockSize(4<<20))
+		}
+	}
+	testOptions = x
+	x = make(map[string][]WriterOption)
+	for name, opt := range testOptions {
+		x[name] = opt
+		x[name+"-pad-min"] = cloneAdd(opt, WriterPadding(2), WriterPaddingSrc(rand.New(rand.NewSource(0))))
+		if !testing.Short() {
+			x[name+"-pad-8000"] = cloneAdd(opt, WriterPadding(8000), WriterPaddingSrc(rand.New(rand.NewSource(0))))
+			x[name+"-pad-max"] = cloneAdd(opt, WriterPadding(4<<20), WriterPaddingSrc(rand.New(rand.NewSource(0))))
+		}
+	}
+	testOptions = x
+	return testOptions
+}
+
 func TestEncoderRegression(t *testing.T) {
 	data, err := ioutil.ReadFile("testdata/enc_regressions.zip")
 	if err != nil {
@@ -27,107 +72,112 @@ func TestEncoderRegression(t *testing.T) {
 	}
 	// Same as fuzz test...
 	test := func(t *testing.T, data []byte) {
-		dec := NewReader(nil)
-		enc := NewWriter(nil, WriterConcurrency(2), WriterPadding(255), WriterBlockSize(128<<10))
-		encBetter := NewWriter(nil, WriterConcurrency(2), WriterPadding(255), WriterBetterCompression(), WriterBlockSize(512<<10))
+		for name, opts := range testOptions(t) {
+			t.Run(name, func(t *testing.T) {
+				dec := NewReader(nil)
+				enc := NewWriter(nil, opts...)
 
-		comp := Encode(make([]byte, MaxEncodedLen(len(data))), data)
-		decoded, err := Decode(nil, comp)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if !bytes.Equal(data, decoded) {
-			t.Error("block decoder mismatch")
-			return
-		}
-		if mel := MaxEncodedLen(len(data)); len(comp) > mel {
-			t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
-			return
-		}
-		comp = EncodeBetter(make([]byte, MaxEncodedLen(len(data))), data)
-		decoded, err = Decode(nil, comp)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if !bytes.Equal(data, decoded) {
-			t.Error("block decoder mismatch")
-			return
-		}
-		if mel := MaxEncodedLen(len(data)); len(comp) > mel {
-			t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
-			return
-		}
-		// Test writer and use "better":
-		var buf bytes.Buffer
-		encBetter.Reset(&buf)
-		n, err := encBetter.Write(data)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if n != len(data) {
-			t.Error(fmt.Errorf("Write: Short write, want %d, got %d", len(data), n))
-			return
-		}
-		err = encBetter.Close()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		// Calling close twice should not affect anything.
-		err = encBetter.Close()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		comp = buf.Bytes()
-		if len(comp)%255 != 0 {
-			t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", 255, len(comp), len(comp)%255))
-			return
-		}
-		dec.Reset(&buf)
-		got, err := ioutil.ReadAll(dec)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if !bytes.Equal(data, got) {
-			t.Error("block (reset) decoder mismatch")
-			return
-		}
-		// Test Reset on both and use ReadFrom instead.
-		input := bytes.NewBuffer(data)
-		buf = bytes.Buffer{}
-		enc.Reset(&buf)
-		n2, err := enc.ReadFrom(input)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if n2 != int64(len(data)) {
-			t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2))
-			return
-		}
-		err = enc.Close()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if buf.Len()%255 != 0 {
-			t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", 255, buf.Len(), buf.Len()%255))
-			return
-		}
-		dec.Reset(&buf)
-		got, err = ioutil.ReadAll(dec)
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		if !bytes.Equal(data, got) {
-			t.Error("frame (reset) decoder mismatch")
-			return
+				comp := Encode(make([]byte, MaxEncodedLen(len(data))), data)
+				decoded, err := Decode(nil, comp)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, decoded) {
+					t.Error("block decoder mismatch")
+					return
+				}
+				if mel := MaxEncodedLen(len(data)); len(comp) > mel {
+					t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
+					return
+				}
+				comp = EncodeBetter(make([]byte, MaxEncodedLen(len(data))), data)
+				decoded, err = Decode(nil, comp)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, decoded) {
+					t.Error("block decoder mismatch")
+					return
+				}
+				if mel := MaxEncodedLen(len(data)); len(comp) > mel {
+					t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
+					return
+				}
+
+				// Test writer.
+				var buf bytes.Buffer
+				enc.Reset(&buf)
+				n, err := enc.Write(data)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if n != len(data) {
+					t.Error(fmt.Errorf("Write: Short write, want %d, got %d", len(data), n))
+					return
+				}
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				// Calling close twice should not affect anything.
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				comp = buf.Bytes()
+				if enc.pad > 0 && len(comp)%enc.pad != 0 {
+					t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", enc.pad, len(comp), len(comp)%enc.pad))
+					return
+				}
+				dec.Reset(&buf)
+				got, err := ioutil.ReadAll(dec)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, got) {
+					t.Error("block (reset) decoder mismatch")
+					return
+				}
+
+				// Test Reset on both and use ReadFrom instead.
+				input := bytes.NewBuffer(data)
+				buf = bytes.Buffer{}
+				enc.Reset(&buf)
+				n2, err := enc.ReadFrom(input)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if n2 != int64(len(data)) {
+					t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2))
+					return
+				}
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if enc.pad > 0 && buf.Len()%enc.pad != 0 {
+					t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", enc.pad, buf.Len(), buf.Len()%enc.pad))
+					return
+				}
+				dec.Reset(&buf)
+				got, err = ioutil.ReadAll(dec)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, got) {
+					t.Error("frame (reset) decoder mismatch")
+					return
+				}
+			})
 		}
 	}
 	for _, tt := range zr.File {