diff --git a/zstd.go b/zstd.go
index 80507e14e..f28db806e 100644
--- a/zstd.go
+++ b/zstd.go
@@ -6,28 +6,49 @@ import (
 	"github.com/klauspost/compress/zstd"
 )
 
+// zstdMaxBufferedEncoders maximum number of not-in-use zstd encoders
+// If the pool of encoders is exhausted then new encoders will be created on the fly
+const zstdMaxBufferedEncoders = 1
+
 type ZstdEncoderParams struct {
 	Level int
 }
 type ZstdDecoderParams struct {
 }
 
-var zstdEncMap, zstdDecMap sync.Map
+var zstdDecMap sync.Map
+
+var zstdAvailableEncoders sync.Map
 
-func getEncoder(params ZstdEncoderParams) *zstd.Encoder {
-	if ret, ok := zstdEncMap.Load(params); ok {
-		return ret.(*zstd.Encoder)
+func getZstdEncoderChannel(params ZstdEncoderParams) chan *zstd.Encoder {
+	if c, ok := zstdAvailableEncoders.Load(params); ok {
+		return c.(chan *zstd.Encoder)
 	}
-	// It's possible to race and create multiple new writers.
-	// Only one will survive GC after use.
-	encoderLevel := zstd.SpeedDefault
-	if params.Level != CompressionLevelDefault {
-		encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
+	c, _ := zstdAvailableEncoders.LoadOrStore(params, make(chan *zstd.Encoder, zstdMaxBufferedEncoders))
+	return c.(chan *zstd.Encoder)
+}
+
+func getZstdEncoder(params ZstdEncoderParams) *zstd.Encoder {
+	select {
+	case enc := <-getZstdEncoderChannel(params):
+		return enc
+	default:
+		encoderLevel := zstd.SpeedDefault
+		if params.Level != CompressionLevelDefault {
+			encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
+		}
+		zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
+			zstd.WithEncoderLevel(encoderLevel),
+			zstd.WithEncoderConcurrency(1))
+		return zstdEnc
+	}
+}
+
+func releaseEncoder(params ZstdEncoderParams, enc *zstd.Encoder) {
+	select {
+	case getZstdEncoderChannel(params) <- enc:
+	default:
 	}
-	zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
-		zstd.WithEncoderLevel(encoderLevel))
-	zstdEncMap.Store(params, zstdEnc)
-	return zstdEnc
 }
 
 func getDecoder(params ZstdDecoderParams) *zstd.Decoder {
@@ -46,5 +67,8 @@ func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) {
 }
 
 func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) {
-	return getEncoder(params).EncodeAll(src, dst), nil
+	enc := getZstdEncoder(params)
+	out := enc.EncodeAll(src, dst)
+	releaseEncoder(params, enc)
+	return out, nil
 }
diff --git a/zstd_test.go b/zstd_test.go
new file mode 100644
index 000000000..efdc6d83d
--- /dev/null
+++ b/zstd_test.go
@@ -0,0 +1,29 @@
+package sarama
+
+import (
+	"runtime"
+	"testing"
+)
+
+func BenchmarkZstdMemoryConsumption(b *testing.B) {
+	params := ZstdEncoderParams{Level: 9}
+	buf := make([]byte, 1024*1024)
+	for i := 0; i < len(buf); i++ {
+		buf[i] = byte((i / 256) + (i * 257))
+	}
+
+	cpus := 96
+
+	gomaxprocsBackup := runtime.GOMAXPROCS(cpus)
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		for j := 0; j < 2*cpus; j++ {
+			_, _ = zstdCompress(params, nil, buf)
+		}
+		// drain the buffered encoder
+		getZstdEncoder(params)
+		// previously this would be achieved with
+		// zstdEncMap.Delete(params)
+	}
+	runtime.GOMAXPROCS(gomaxprocsBackup)
+}