From 86f0c4a3a9d392315cec208f13f4a72d63b75895 Mon Sep 17 00:00:00 2001 From: jonjohnsonjr Date: Wed, 29 Jun 2022 14:22:50 -0700 Subject: [PATCH] Wrap progress updates in a mutex (#1402) Atomically incrementing the number of bytes written isn't sufficient if we're sending the updates out of order. I ran the progress tests 50 times after this change and they passed. --- pkg/v1/remote/multi_write.go | 20 +++---- pkg/v1/remote/progress.go | 69 +++++++++++++++++++++++ pkg/v1/remote/write.go | 103 ++++++++++++----------------------- 3 files changed, 114 insertions(+), 78 deletions(-) create mode 100644 pkg/v1/remote/progress.go diff --git a/pkg/v1/remote/multi_write.go b/pkg/v1/remote/multi_write.go index 7e41d94c4..002ef8587 100644 --- a/pkg/v1/remote/multi_write.go +++ b/pkg/v1/remote/multi_write.go @@ -87,32 +87,32 @@ func MultiWrite(m map[name.Reference]Taggable, options ...Option) (rerr error) { return err } w := writer{ - repo: repo, - client: &http.Client{Transport: tr}, - context: o.context, - updates: o.updates, - lastUpdate: &v1.Update{}, - backoff: o.retryBackoff, - predicate: o.retryPredicate, + repo: repo, + client: &http.Client{Transport: tr}, + context: o.context, + backoff: o.retryBackoff, + predicate: o.retryPredicate, } // Collect the total size of blobs and manifests we're about to write. if o.updates != nil { + w.progress = &progress{updates: o.updates} + w.progress.lastUpdate = &v1.Update{} defer close(o.updates) - defer func() { _ = sendError(o.updates, rerr) }() + defer func() { _ = w.progress.err(rerr) }() for _, b := range blobs { size, err := b.Size() if err != nil { return err } - w.lastUpdate.Total += size + w.progress.total(size) } countManifest := func(t Taggable) error { b, err := t.RawManifest() if err != nil { return err } - w.lastUpdate.Total += int64(len(b)) + w.progress.total(int64(len(b))) return nil } for _, i := range images { diff --git a/pkg/v1/remote/progress.go b/pkg/v1/remote/progress.go new file mode 100644 index 000000000..1f4396350 --- /dev/null +++ b/pkg/v1/remote/progress.go @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remote + +import ( + "io" + "sync" + "sync/atomic" + + v1 "github.com/google/go-containerregistry/pkg/v1" +) + +type progress struct { + sync.Mutex + updates chan<- v1.Update + lastUpdate *v1.Update +} + +func (p *progress) total(delta int64) { + atomic.AddInt64(&p.lastUpdate.Total, delta) +} + +func (p *progress) complete(delta int64) { + p.Lock() + defer p.Unlock() + p.updates <- v1.Update{ + Total: p.lastUpdate.Total, + Complete: atomic.AddInt64(&p.lastUpdate.Complete, delta), + } +} + +func (p *progress) err(err error) error { + if err != nil && p.updates != nil { + p.updates <- v1.Update{Error: err} + } + return err +} + +type progressReader struct { + rc io.ReadCloser + + count *int64 // number of bytes this reader has read, to support resetting on retry. + progress *progress +} + +func (r *progressReader) Read(b []byte) (int, error) { + n, err := r.rc.Read(b) + if err != nil { + return n, err + } + atomic.AddInt64(r.count, int64(n)) + // TODO: warn/debug log if sending takes too long, or if sending is blocked while context is canceled. + r.progress.complete(int64(n)) + return n, nil +} + +func (r *progressReader) Close() error { return r.rc.Close() } diff --git a/pkg/v1/remote/write.go b/pkg/v1/remote/write.go index 7b9f31e19..2a68ed87e 100644 --- a/pkg/v1/remote/write.go +++ b/pkg/v1/remote/write.go @@ -23,7 +23,6 @@ import ( "net/http" "net/url" "strings" - "sync/atomic" "github.com/google/go-containerregistry/internal/redact" "github.com/google/go-containerregistry/internal/retry" @@ -49,20 +48,21 @@ func Write(ref name.Reference, img v1.Image, options ...Option) (rerr error) { return err } - var lastUpdate *v1.Update + var p *progress if o.updates != nil { - lastUpdate = &v1.Update{} - lastUpdate.Total, err = countImage(img, o.allowNondistributableArtifacts) + p = &progress{updates: o.updates} + p.lastUpdate = &v1.Update{} + p.lastUpdate.Total, err = countImage(img, o.allowNondistributableArtifacts) if err != nil { return err } defer close(o.updates) - defer func() { _ = sendError(o.updates, rerr) }() + defer func() { _ = p.err(rerr) }() } - return writeImage(o.context, ref, img, o, lastUpdate) + return writeImage(o.context, ref, img, o, p) } -func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *options, lastUpdate *v1.Update) error { +func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *options, progress *progress) error { ls, err := img.Layers() if err != nil { return err @@ -73,13 +73,12 @@ func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *option return err } w := writer{ - repo: ref.Context(), - client: &http.Client{Transport: tr}, - context: ctx, - updates: o.updates, - lastUpdate: lastUpdate, - backoff: o.retryBackoff, - predicate: o.retryPredicate, + repo: ref.Context(), + client: &http.Client{Transport: tr}, + context: ctx, + progress: progress, + backoff: o.retryBackoff, + predicate: o.retryPredicate, } // Upload individual blobs and collect any errors. @@ -174,17 +173,9 @@ type writer struct { client *http.Client context context.Context - updates chan<- v1.Update - lastUpdate *v1.Update - backoff Backoff - predicate retry.Predicate -} - -func sendError(ch chan<- v1.Update, err error) error { - if err != nil && ch != nil { - ch <- v1.Update{Error: err} - } - return err + progress *progress + backoff Backoff + predicate retry.Predicate } // url returns a url.Url for the specified path in the context of this remote image reference. @@ -310,30 +301,6 @@ func (w *writer) initiateUpload(from, mount, origin string) (location string, mo } } -type progressReader struct { - rc io.ReadCloser - - count *int64 // number of bytes this reader has read, to support resetting on retry. - updates chan<- v1.Update - lastUpdate *v1.Update -} - -func (r *progressReader) Read(b []byte) (int, error) { - n, err := r.rc.Read(b) - if err != nil { - return n, err - } - atomic.AddInt64(r.count, int64(n)) - // TODO: warn/debug log if sending takes too long, or if sending is blocked while context is cancelled. - r.updates <- v1.Update{ - Total: r.lastUpdate.Total, - Complete: atomic.AddInt64(&r.lastUpdate.Complete, int64(n)), - } - return n, nil -} - -func (r *progressReader) Close() error { return r.rc.Close() } - // streamBlob streams the contents of the blob to the specified location. // On failure, this will return an error. On success, this will return the location // header indicating how to commit the streamed blob. @@ -350,19 +317,18 @@ func (w *writer) streamBlob(ctx context.Context, layer v1.Layer, streamLocation } getBody := layer.Compressed - if w.updates != nil { + if w.progress != nil { var count int64 - blob = &progressReader{rc: blob, updates: w.updates, lastUpdate: w.lastUpdate, count: &count} + blob = &progressReader{rc: blob, progress: w.progress, count: &count} getBody = func() (io.ReadCloser, error) { blob, err := layer.Compressed() if err != nil { return nil, err } - return &progressReader{rc: blob, updates: w.updates, lastUpdate: w.lastUpdate, count: &count}, nil + return &progressReader{rc: blob, progress: w.progress, count: &count}, nil } reset = func() { - atomic.AddInt64(&w.lastUpdate.Complete, -count) - w.updates <- *w.lastUpdate + w.progress.complete(-count) } } @@ -419,13 +385,10 @@ func (w *writer) commitBlob(location, digest string) error { // incrProgress increments and sends a progress update, if WithProgress is used. func (w *writer) incrProgress(written int64) { - if w.updates == nil { + if w.progress == nil { return } - w.updates <- v1.Update{ - Total: w.lastUpdate.Total, - Complete: atomic.AddInt64(&w.lastUpdate.Complete, written), - } + w.progress.complete(written) } // uploadOne performs a complete upload of a single layer. @@ -546,7 +509,7 @@ func (w *writer) writeIndex(ctx context.Context, ref name.Reference, ii v1.Image if err != nil { return err } - if err := writeImage(ctx, ref, img, o, w.lastUpdate); err != nil { + if err := writeImage(ctx, ref, img, o, w.progress); err != nil { return err } default: @@ -689,19 +652,21 @@ func WriteIndex(ref name.Reference, ii v1.ImageIndex, options ...Option) (rerr e repo: ref.Context(), client: &http.Client{Transport: tr}, context: o.context, - updates: o.updates, backoff: o.retryBackoff, predicate: o.retryPredicate, } if o.updates != nil { - w.lastUpdate = &v1.Update{} - w.lastUpdate.Total, err = countIndex(ii, o.allowNondistributableArtifacts) + w.progress = &progress{updates: o.updates} + w.progress.lastUpdate = &v1.Update{} + + defer close(o.updates) + defer func() { w.progress.err(rerr) }() + + w.progress.lastUpdate.Total, err = countIndex(ii, o.allowNondistributableArtifacts) if err != nil { return err } - defer close(o.updates) - defer func() { sendError(o.updates, rerr) }() } return w.writeIndex(o.context, ref, ii, options...) @@ -830,14 +795,16 @@ func WriteLayer(repo name.Repository, layer v1.Layer, options ...Option) (rerr e repo: repo, client: &http.Client{Transport: tr}, context: o.context, - updates: o.updates, backoff: o.retryBackoff, predicate: o.retryPredicate, } if o.updates != nil { + w.progress = &progress{updates: o.updates} + w.progress.lastUpdate = &v1.Update{} + defer close(o.updates) - defer func() { sendError(o.updates, rerr) }() + defer func() { w.progress.err(rerr) }() // TODO: support streaming layers which update the total count as they write. if _, ok := layer.(*stream.Layer); ok { @@ -847,7 +814,7 @@ func WriteLayer(repo name.Repository, layer v1.Layer, options ...Option) (rerr e if err != nil { return err } - w.lastUpdate = &v1.Update{Total: size} + w.progress.total(size) } return w.uploadOne(o.context, layer) }