Skip to content

Commit

Permalink
fix (extras/crypto): add guarding for Sha256SumWriter/Sha256SumReader (
Browse files Browse the repository at this point in the history
…#169)

Add some mutexes in Sha256SumWriter and Sha256SumReader to ensure
we don't get panics
  • Loading branch information
jimlambrt committed Jun 1, 2023
1 parent 2964a28 commit 08d524b
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions extras/crypto/sha256sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"hash"
"io"
"sync"

wrapping "github.com/hashicorp/go-kms-wrapping/v2"
)
Expand Down Expand Up @@ -44,6 +45,7 @@ func Sha256Sum(ctx context.Context, r io.Reader, opt ...wrapping.Option) ([]byte
// Sha256SumWriter provides multi-writer which will be used to write to a
// hash and produce a sum. It implements io.WriterCloser and io.StringWriter.
type Sha256SumWriter struct {
l sync.Mutex
hash hash.Hash
tee io.Writer
w io.Writer
Expand All @@ -69,6 +71,8 @@ func NewSha256SumWriter(ctx context.Context, w io.Writer) (*Sha256SumWriter, err
// func.
func (w *Sha256SumWriter) Write(b []byte) (int, error) {
const op = "crypto.(Sha256SumWriter).Write"
w.l.Lock()
defer w.l.Unlock()
n, err := w.tee.Write(b)
if err != nil {
return n, fmt.Errorf("%s: %w", op, err)
Expand Down Expand Up @@ -106,6 +110,8 @@ func (w *Sha256SumWriter) Sum(_ context.Context, opt ...wrapping.Option) ([]byte
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
w.l.Lock()
defer w.l.Unlock()
h := w.hash.Sum(nil)
switch {
case opts.WithHexEncoding:
Expand All @@ -119,6 +125,7 @@ func (w *Sha256SumWriter) Sum(_ context.Context, opt ...wrapping.Option) ([]byte
// Sha256SumReader provides an io.Reader which can be used to calculate a sum
// while reading a file. It implements io.ReaderCloser.
type Sha256SumReader struct {
l sync.Mutex
hash hash.Hash
tee io.Reader
r io.Reader
Expand All @@ -142,6 +149,8 @@ func NewSha256SumReader(_ context.Context, r io.Reader) (*Sha256SumReader, error

func (r *Sha256SumReader) Read(b []byte) (int, error) {
const op = "crypto.(Sha256SumReader).Read"
r.l.Lock()
defer r.l.Unlock()
n, err := r.tee.Read(b)
if err != nil {
return n, fmt.Errorf("%s: %w", op, err)
Expand Down Expand Up @@ -170,6 +179,8 @@ func (r *Sha256SumReader) Sum(_ context.Context, opt ...wrapping.Option) ([]byte
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
r.l.Lock()
defer r.l.Unlock()
h := r.hash.Sum(nil)
switch {
case opts.WithHexEncoding:
Expand Down

0 comments on commit 08d524b

Please sign in to comment.