diff --git a/internal/io/copy.go b/internal/io/copy.go index e22dbdb030..7ea5e3278d 100644 --- a/internal/io/copy.go +++ b/internal/io/copy.go @@ -22,6 +22,7 @@ import ( "io" "math" "sync/atomic" + "syscall" "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/sync" @@ -33,8 +34,13 @@ func Copy(dst io.Writer, src io.Reader) (written int64, err error) { return cio.Copy(dst, src) } +func CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + return cio.CopyBuffer(dst, src, buf) +} + type Copier interface { Copy(dst io.Writer, src io.Reader) (written int64, err error) + CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) } type copier struct { @@ -42,9 +48,7 @@ type copier struct { pool sync.Pool } -const ( - defaultBufferSize int = 64 * 1024 -) +var defaultBufferSize int = 16 * syscall.Getpagesize() func NewCopier(size int) Copier { c := new(copier) @@ -62,6 +66,22 @@ func NewCopier(size int) Copier { } func (c *copier) Copy(dst io.Writer, src io.Reader) (written int64, err error) { + return c.copyBuffer(dst, src, nil) +} + +func (c *copier) CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + if buf == nil { + return c.Copy(dst, src) + } + if buf != nil && len(buf) == 0 { + panic("empty buffer in CopyBuffer") + } + b := bytes.NewBuffer(buf) + defer b.Reset() + return c.copyBuffer(dst, src, b) +} + +func (c *copier) copyBuffer(dst io.Writer, src io.Reader, buf *bytes.Buffer) (written int64, err error) { if dst == nil || src == nil { return 0, errors.New("empty source or destination") } @@ -81,24 +101,25 @@ func (c *copier) Copy(dst io.Writer, src io.Reader) (written int64, err error) { limit int64 = math.MaxInt64 size int64 = atomic.LoadInt64(&c.bufSize) l *io.LimitedReader - buf *bytes.Buffer ) - if l, ok = src.(*io.LimitedReader); ok && l.N >= 1 && size > l.N { - limit = l.N - size = limit - } - buf, ok = c.pool.Get().(*bytes.Buffer) - if !ok || buf == nil { - buf = bytes.NewBuffer(make([]byte, size)) - } - defer func() { - if atomic.LoadInt64(&c.bufSize) < size { - atomic.StoreInt64(&c.bufSize, size) - buf.Grow(int(size)) + if buf == nil { + if l, ok = src.(*io.LimitedReader); ok && l.N >= 1 && size > l.N { + limit = l.N + size = limit } - buf.Reset() - c.pool.Put(buf) - }() + buf, ok = c.pool.Get().(*bytes.Buffer) + if !ok || buf == nil { + buf = bytes.NewBuffer(make([]byte, size)) + } + defer func() { + if atomic.LoadInt64(&c.bufSize) < size { + atomic.StoreInt64(&c.bufSize, size) + buf.Grow(int(size)) + } + buf.Reset() + c.pool.Put(buf) + }() + } if size > int64(buf.Cap()) { size = int64(buf.Cap()) } diff --git a/internal/io/copy_bench_test.go b/internal/io/copy_bench_test.go index 3dd206e5f4..ad33d00c17 100644 --- a/internal/io/copy_bench_test.go +++ b/internal/io/copy_bench_test.go @@ -82,6 +82,15 @@ func BenchmarkValdIOCopy(b *testing.B) { } } +func BenchmarkValdIOCopyBuffer(b *testing.B) { + c := NewCopier(bufferLength) + for i := 0; i < b.N; i++ { + w := &writer{} + r := &reader{len: readerLength} + c.CopyBuffer(w, r, nil) + } +} + func BenchmarkStandardIOCopyParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -113,3 +122,14 @@ func BenchmarkValdIOCopyParallel(b *testing.B) { } }) } + +func BenchmarkValdIOCopyBufferParallel(b *testing.B) { + c := NewCopier(bufferLength) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + w := &writer{} + r := &reader{len: readerLength} + c.CopyBuffer(w, r, nil) + } + }) +}