diff --git a/fetch.go b/fetch.go index c8a11bf..bdc07cf 100644 --- a/fetch.go +++ b/fetch.go @@ -24,7 +24,7 @@ type fetchProtocol struct { host host.Host } -type getValue func(key string) ([]byte, error) +type getValue func(ctx context.Context, key string) ([]byte, error) func newFetchProtocol(ctx context.Context, host host.Host, getData getValue) *fetchProtocol { p := &fetchProtocol{ctx, host} @@ -46,7 +46,7 @@ func (p *fetchProtocol) receive(s network.Stream, getData getValue) { return } - response, err := getData(msg.Identifier) + response, err := getData(p.ctx, msg.Identifier) var respProto pb.FetchResponse if err != nil { diff --git a/fetch_test.go b/fetch_test.go index 7db6cbe..ac2e305 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -22,7 +22,7 @@ type datastore struct { data map[string][]byte } -func (d *datastore) Lookup(key string) ([]byte, error) { +func (d *datastore) Lookup(ctx context.Context, key string) ([]byte, error) { v, ok := d.data[key] if !ok { return nil, errors.New("key not found") diff --git a/pubsub.go b/pubsub.go index 50fe821..c5976f1 100644 --- a/pubsub.go +++ b/pubsub.go @@ -127,7 +127,7 @@ func (p *PubsubValueStore) PutValue(ctx context.Context, key string, value []byt ti.dbWriteMx.Lock() defer ti.dbWriteMx.Unlock() - recCmp, err := p.putLocal(ti, key, value) + recCmp, err := p.putLocal(ctx, ti, key, value) if err != nil { return err } @@ -147,12 +147,12 @@ func (p *PubsubValueStore) PutValue(ctx context.Context, key string, value []byt // First return value is 0 if equal, greater than 0 if better, less than 0 if worse. // Second return value is true if valid. // -func (p *PubsubValueStore) compare(key string, val []byte) (int, bool) { +func (p *PubsubValueStore) compare(ctx context.Context, key string, val []byte) (int, bool) { if p.Validator.Validate(key, val) != nil { return -1, false } - old, err := p.getLocal(key) + old, err := p.getLocal(ctx, key) if err != nil { // If the old one is invalid, the new one is *always* better. return 1, true @@ -192,7 +192,7 @@ func (p *PubsubValueStore) Subscribe(key string) error { src peer.ID, msg *pubsub.Message, ) pubsub.ValidationResult { - cmp, valid := p.compare(key, msg.GetData()) + cmp, valid := p.compare(ctx, key, msg.GetData()) if !valid { return pubsub.ValidationReject } @@ -271,7 +271,7 @@ func (p *PubsubValueStore) rebroadcast(ctx context.Context) { p.mx.Unlock() if len(topics) > 0 { for i, k := range keys { - val, err := p.getLocal(k) + val, err := p.getLocal(ctx, k) if err == nil { topic := topics[i].topic select { @@ -300,16 +300,16 @@ func (p *PubsubValueStore) psPublishChannel(ctx context.Context, topic *pubsub.T // Requires that the ti.dbWriteMx is held when called // Returns true if the value is better then what is currently in the datastore // Returns any errors from putting the data in the datastore -func (p *PubsubValueStore) putLocal(ti *topicInfo, key string, value []byte) (int, error) { - cmp, valid := p.compare(key, value) +func (p *PubsubValueStore) putLocal(ctx context.Context, ti *topicInfo, key string, value []byte) (int, error) { + cmp, valid := p.compare(ctx, key, value) if valid && cmp > 0 { - return cmp, p.ds.Put(dshelp.NewKeyFromBinary([]byte(key)), value) + return cmp, p.ds.Put(ctx, dshelp.NewKeyFromBinary([]byte(key)), value) } return cmp, nil } -func (p *PubsubValueStore) getLocal(key string) ([]byte, error) { - val, err := p.ds.Get(dshelp.NewKeyFromBinary([]byte(key))) +func (p *PubsubValueStore) getLocal(ctx context.Context, key string) ([]byte, error) { + val, err := p.ds.Get(ctx, dshelp.NewKeyFromBinary([]byte(key))) if err != nil { // Don't invalidate due to ds errors. if err == ds.ErrNotFound { @@ -330,7 +330,7 @@ func (p *PubsubValueStore) GetValue(ctx context.Context, key string, opts ...rou return nil, err } - return p.getLocal(key) + return p.getLocal(ctx, key) } func (p *PubsubValueStore) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) { @@ -342,7 +342,7 @@ func (p *PubsubValueStore) SearchValue(ctx context.Context, key string, opts ... defer p.watchLk.Unlock() out := make(chan []byte, 1) - lv, err := p.getLocal(key) + lv, err := p.getLocal(ctx, key) if err == nil { out <- lv close(out) @@ -513,7 +513,7 @@ func (p *PubsubValueStore) handleSubscription(ctx context.Context, ti *topicInfo } ti.dbWriteMx.Lock() - recCmp, err := p.putLocal(ti, key, data) + recCmp, err := p.putLocal(ctx, ti, key, data) ti.dbWriteMx.Unlock() if recCmp > 0 { if err != nil {