diff --git a/zstd/dict_test.go b/zstd/dict_test.go index e5cf6ac983..34201da8fc 100644 --- a/zstd/dict_test.go +++ b/zstd/dict_test.go @@ -184,7 +184,7 @@ func TestEncoder_SmallDict(t *testing.T) { enc := encs[i] t.Run(encNames[i], func(t *testing.T) { var buf bytes.Buffer - enc.Reset(&buf) + enc.ResetContentSize(&buf, int64(len(decoded))) _, err := enc.Write(decoded) if err != nil { t.Fatal(err) diff --git a/zstd/enc_base.go b/zstd/enc_base.go index 60f2986486..295cd602a4 100644 --- a/zstd/enc_base.go +++ b/zstd/enc_base.go @@ -38,8 +38,8 @@ func (e *fastBase) AppendCRC(dst []byte) []byte { // WindowSize returns the window size of the encoder, // or a window size small enough to contain the input size, if > 0. -func (e *fastBase) WindowSize(size int) int32 { - if size > 0 && size < int(e.maxMatchOff) { +func (e *fastBase) WindowSize(size int64) int32 { + if size > 0 && size < int64(e.maxMatchOff) { b := int32(1) << uint(bits.Len(uint(size))) // Keep minimum window. if b < 1024 { diff --git a/zstd/encoder.go b/zstd/encoder.go index ea85548fc9..e6e315969b 100644 --- a/zstd/encoder.go +++ b/zstd/encoder.go @@ -33,7 +33,7 @@ type encoder interface { Block() *blockEnc CRC() *xxhash.Digest AppendCRC([]byte) []byte - WindowSize(size int) int32 + WindowSize(size int64) int32 UseBlock(*blockEnc) Reset(d *dict, singleBlock bool) } @@ -48,6 +48,8 @@ type encoderState struct { err error writeErr error nWritten int64 + nInput int64 + frameContentSize int64 headerWritten bool eofWritten bool fullFrameWritten bool @@ -120,7 +122,21 @@ func (e *Encoder) Reset(w io.Writer) { s.w = w s.err = nil s.nWritten = 0 + s.nInput = 0 s.writeErr = nil + s.frameContentSize = 0 +} + +// ResetContentSize will reset and set a content size for the next stream. +// If the bytes written does not match the size given an error will be returned +// when calling Close(). +// This is removed when Reset is called. +// Sizes <= 0 results in no content size set. +func (e *Encoder) ResetContentSize(w io.Writer, size int64) { + e.Reset(w) + if size >= 0 { + e.state.frameContentSize = size + } } // Write data to the encoder. @@ -190,6 +206,7 @@ func (e *Encoder) nextBlock(final bool) error { return s.err } s.nWritten += int64(n2) + s.nInput += int64(len(s.filling)) s.current = s.current[:0] s.filling = s.filling[:0] s.headerWritten = true @@ -200,8 +217,8 @@ func (e *Encoder) nextBlock(final bool) error { var tmp [maxHeaderSize]byte fh := frameHeader{ - ContentSize: 0, - WindowSize: uint32(s.encoder.WindowSize(0)), + ContentSize: uint64(s.frameContentSize), + WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)), SingleSegment: false, Checksum: e.o.crc, DictID: e.o.dict.ID(), @@ -243,6 +260,7 @@ func (e *Encoder) nextBlock(final bool) error { // Move blocks forward. s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current + s.nInput += int64(len(s.current)) s.wg.Add(1) go func(src []byte) { if debugEncoder { @@ -394,6 +412,11 @@ func (e *Encoder) Close() error { if err != nil { return err } + if s.frameContentSize > 0 { + if s.nInput != s.frameContentSize { + return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput) + } + } if e.state.fullFrameWritten { return s.err } @@ -470,7 +493,7 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { } fh := frameHeader{ ContentSize: uint64(len(src)), - WindowSize: uint32(enc.WindowSize(len(src))), + WindowSize: uint32(enc.WindowSize(int64(len(src)))), SingleSegment: single, Checksum: e.o.crc, DictID: e.o.dict.ID(), diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index 2ac99a1dbc..b5ddf9fe57 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -252,7 +252,7 @@ func TestEncoderRegression(t *testing.T) { // Use the Writer var dst bytes.Buffer - enc.Reset(&dst) + enc.ResetContentSize(&dst, int64(len(in))) _, err = enc.Write(in) if err != nil { t.Error(err) @@ -407,7 +407,7 @@ func TestWithEncoderPadding(t *testing.T) { // Test using the writer. var buf bytes.Buffer - e.Reset(&buf) + e.ResetContentSize(&buf, int64(len(src))) _, err = io.Copy(e, bytes.NewBuffer(src)) if err != nil { t.Fatal(err)