Skip to content

Commit

Permalink
remove context from struct and pass it in when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
kmulvey committed Apr 13, 2023
1 parent caf5fff commit be6b7ae
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 28 deletions.
4 changes: 2 additions & 2 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func main() {
os.Exit(1)
}

var ch = concurrenthash.NewConcurrentHash(ctx, threads, blockSize, argToHashFuncMap[hashFunc])
var hash, err = ch.HashFile(file)
var ch = concurrenthash.NewConcurrentHash(threads, blockSize, argToHashFuncMap[hashFunc])
var hash, err = ch.HashFile(ctx, file)
if err != nil {
fmt.Printf("Encountered an error: %s", err.Error())
return
Expand Down
6 changes: 4 additions & 2 deletions collect.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package concurrenthash

import "context"

// collectSums is a fan in func to get the hashes and write them to an array
func (c *ConcurrentHash) collectSums(sums <-chan sum) {
func (c *ConcurrentHash) collectSums(ctx context.Context, sums <-chan sum) {
for {
select {
case <-c.Context.Done():
case <-ctx.Done():
return
default:
select {
Expand Down
4 changes: 2 additions & 2 deletions collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ func TestCollectSums(t *testing.T) {
t.Parallel()

var ctx, cancel = context.WithCancel(context.Background())
var cs = NewConcurrentHash(ctx, 2, 10, sha256.New)
var cs = NewConcurrentHash(2, 10, sha256.New)
cs.Hashes = make([][]byte, 2)
var sums = make(chan sum)
go cs.collectSums(sums)
go cs.collectSums(ctx, sums)

sums <- sum{
Index: 1,
Expand Down
12 changes: 5 additions & 7 deletions concurrenthash.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,22 @@ type ConcurrentHash struct {
HashConstructor func() hash.Hash

// internal
Context context.Context
Hashes [][]byte
HashesLock sync.RWMutex
}

// NewConcurrentHash is the constructor and entrypoint
func NewConcurrentHash(ctx context.Context, concurrency int, blockSize int64, hashFunc func() hash.Hash) ConcurrentHash {
func NewConcurrentHash(concurrency int, blockSize int64, hashFunc func() hash.Hash) ConcurrentHash {
return ConcurrentHash{
Concurrency: concurrency,
BlockSize: blockSize,
HashConstructor: hashFunc,
Context: ctx,
}
}

// HashFile is a coordination func that fans out to hash workers,
// collects their output and hashes the final array
func (c *ConcurrentHash) HashFile(file string) (string, error) {
func (c *ConcurrentHash) HashFile(ctx context.Context, file string) (string, error) {

// make sure the file even exists first
var stat, err = os.Stat(file)
Expand All @@ -66,11 +64,11 @@ func (c *ConcurrentHash) HashFile(file string) (string, error) {
var sums = make(chan sum)
sumChans[i] = sums
errGroup.Go(func() error {
return c.hashBlock(blocks, sums)
return c.hashBlock(ctx, blocks, sums)
})
}
errGroup.Go(func() error {
c.collectSums(goutils.MergeChannels(sumChans...))
c.collectSums(ctx, goutils.MergeChannels(sumChans...))
return nil
})
errGroup.Go(func() error {
Expand All @@ -94,7 +92,7 @@ func (c *ConcurrentHash) HashFile(file string) (string, error) {

var h = c.HashConstructor()
h.Reset()
h.Write(buf.Bytes())
_, err = h.Write(buf.Bytes())
if err != nil {
return "", err
}
Expand Down
7 changes: 4 additions & 3 deletions concurrenthash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
func TestEverything(t *testing.T) {
t.Parallel()

var cs = NewConcurrentHash(context.Background(), 2, 10, sha256.New)
var sum, err = cs.HashFile("./rand-file.txt")
var ctx = context.Background()
var cs = NewConcurrentHash(2, 10, sha256.New)
var sum, err = cs.HashFile(ctx, "./rand-file.txt")
assert.NoError(t, err)
assert.Equal(t, "bf842e96b246556052bc7e518de1fdf7c4a5a859ad104a201880074bece30b82", sum)

sum, err = cs.HashFile("./sdfsdfsf.txt")
sum, err = cs.HashFile(ctx, "./sdfsdfsf.txt")
assert.Equal(t, "stat ./sdfsdfsf.txt: no such file or directory", err.Error())
assert.Equal(t, "", sum)
}
5 changes: 2 additions & 3 deletions file_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package concurrenthash

import (
"context"
"crypto/sha256"
"testing"

Expand All @@ -25,7 +24,7 @@ func TestReadFile(t *testing.T) {
t.Parallel()

var blocks = make(chan block)
var cs = NewConcurrentHash(context.Background(), 1, 10, sha256.New)
var cs = NewConcurrentHash(1, 10, sha256.New)

var done = make(chan struct{})
go func() {
Expand All @@ -45,7 +44,7 @@ func TestReadFileNotExist(t *testing.T) {
t.Parallel()

var blocks = make(chan block)
var cs = NewConcurrentHash(context.Background(), 1, 10, sha256.New)
var cs = NewConcurrentHash(1, 10, sha256.New)

var done = make(chan struct{})
go func() {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.19
require (
github.com/kmulvey/goutils v0.6.0
github.com/stretchr/testify v1.8.1
github.com/twmb/murmur3 v1.1.6
github.com/twmb/murmur3 v1.1.7
golang.org/x/sync v0.1.0
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg=
github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
github.com/twmb/murmur3 v1.1.7 h1:ULWBiM04n/XoN3YMSJ6Z2pHDFLf+MeIVQU71ZPrvbWg=
github.com/twmb/murmur3 v1.1.7/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
Expand Down
6 changes: 4 additions & 2 deletions hash.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package concurrenthash

import "context"

// hashBlock runs the hash func on each block of bytes
func (c *ConcurrentHash) hashBlock(blocks <-chan block, sums chan<- sum) error {
func (c *ConcurrentHash) hashBlock(ctx context.Context, blocks <-chan block, sums chan<- sum) error {
defer close(sums)
var h = c.HashConstructor()
for {
select {
case <-c.Context.Done():
case <-ctx.Done():
return nil
default:
select {
Expand Down
4 changes: 2 additions & 2 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestHash(t *testing.T) {
var blocks = make(chan block)
var sums = make(chan sum)
var ctx, cancel = context.WithCancel(context.Background())
var cs = ConcurrentHash{Context: ctx, HashConstructor: sha256.New}
var cs = ConcurrentHash{HashConstructor: sha256.New}

go func() {
for sum := range sums {
Expand All @@ -26,7 +26,7 @@ func TestHash(t *testing.T) {
}()

go func() {
assert.NoError(t, cs.hashBlock(blocks, sums))
assert.NoError(t, cs.hashBlock(ctx, blocks, sums))
}()

blocks <- block{Index: 0, Data: []byte{0x32, 0x96, 0xd0, 0x3, 0x2b, 0x56, 0x72, 0x2b, 0xaf, 0x39}}
Expand Down
4 changes: 2 additions & 2 deletions wrappers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func TestWrappers(t *testing.T) {
t.Parallel()

for _, pair := range testMatrix {
var cs = NewConcurrentHash(context.Background(), 2, 10, pair.HashFunc)
var sum, err = cs.HashFile("./rand-file.txt")
var cs = NewConcurrentHash(2, 10, pair.HashFunc)
var sum, err = cs.HashFile(context.Background(), "./rand-file.txt")
assert.NoError(t, err)
assert.Equal(t, pair.Expected, sum)
}
Expand Down

0 comments on commit be6b7ae

Please sign in to comment.