Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Add stream content size #401

Merged
merged 1 commit into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zstd/dict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestEncoder_SmallDict(t *testing.T) {
enc := encs[i]
t.Run(encNames[i], func(t *testing.T) {
var buf bytes.Buffer
enc.Reset(&buf)
enc.ResetContentSize(&buf, int64(len(decoded)))
_, err := enc.Write(decoded)
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions zstd/enc_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func (e *fastBase) AppendCRC(dst []byte) []byte {

// WindowSize returns the window size of the encoder,
// or a window size small enough to contain the input size, if > 0.
func (e *fastBase) WindowSize(size int) int32 {
if size > 0 && size < int(e.maxMatchOff) {
func (e *fastBase) WindowSize(size int64) int32 {
if size > 0 && size < int64(e.maxMatchOff) {
b := int32(1) << uint(bits.Len(uint(size)))
// Keep minimum window.
if b < 1024 {
Expand Down
31 changes: 27 additions & 4 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type encoder interface {
Block() *blockEnc
CRC() *xxhash.Digest
AppendCRC([]byte) []byte
WindowSize(size int) int32
WindowSize(size int64) int32
UseBlock(*blockEnc)
Reset(d *dict, singleBlock bool)
}
Expand All @@ -48,6 +48,8 @@ type encoderState struct {
err error
writeErr error
nWritten int64
nInput int64
frameContentSize int64
headerWritten bool
eofWritten bool
fullFrameWritten bool
Expand Down Expand Up @@ -120,7 +122,21 @@ func (e *Encoder) Reset(w io.Writer) {
s.w = w
s.err = nil
s.nWritten = 0
s.nInput = 0
s.writeErr = nil
s.frameContentSize = 0
}

// ResetContentSize will reset and set a content size for the next stream.
// If the bytes written does not match the size given an error will be returned
// when calling Close().
// This is removed when Reset is called.
// Sizes <= 0 results in no content size set.
func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
e.Reset(w)
if size >= 0 {
e.state.frameContentSize = size
}
}

// Write data to the encoder.
Expand Down Expand Up @@ -190,6 +206,7 @@ func (e *Encoder) nextBlock(final bool) error {
return s.err
}
s.nWritten += int64(n2)
s.nInput += int64(len(s.filling))
s.current = s.current[:0]
s.filling = s.filling[:0]
s.headerWritten = true
Expand All @@ -200,8 +217,8 @@ func (e *Encoder) nextBlock(final bool) error {

var tmp [maxHeaderSize]byte
fh := frameHeader{
ContentSize: 0,
WindowSize: uint32(s.encoder.WindowSize(0)),
ContentSize: uint64(s.frameContentSize),
WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)),
SingleSegment: false,
Checksum: e.o.crc,
DictID: e.o.dict.ID(),
Expand Down Expand Up @@ -243,6 +260,7 @@ func (e *Encoder) nextBlock(final bool) error {

// Move blocks forward.
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
s.nInput += int64(len(s.current))
s.wg.Add(1)
go func(src []byte) {
if debugEncoder {
Expand Down Expand Up @@ -394,6 +412,11 @@ func (e *Encoder) Close() error {
if err != nil {
return err
}
if s.frameContentSize > 0 {
if s.nInput != s.frameContentSize {
return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
}
}
if e.state.fullFrameWritten {
return s.err
}
Expand Down Expand Up @@ -470,7 +493,7 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte {
}
fh := frameHeader{
ContentSize: uint64(len(src)),
WindowSize: uint32(enc.WindowSize(len(src))),
WindowSize: uint32(enc.WindowSize(int64(len(src)))),
SingleSegment: single,
Checksum: e.o.crc,
DictID: e.o.dict.ID(),
Expand Down
4 changes: 2 additions & 2 deletions zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func TestEncoderRegression(t *testing.T) {

// Use the Writer
var dst bytes.Buffer
enc.Reset(&dst)
enc.ResetContentSize(&dst, int64(len(in)))
_, err = enc.Write(in)
if err != nil {
t.Error(err)
Expand Down Expand Up @@ -407,7 +407,7 @@ func TestWithEncoderPadding(t *testing.T) {

// Test using the writer.
var buf bytes.Buffer
e.Reset(&buf)
e.ResetContentSize(&buf, int64(len(src)))
_, err = io.Copy(e, bytes.NewBuffer(src))
if err != nil {
t.Fatal(err)
Expand Down