diff --git a/bitswap.go b/bitswap.go index 7e50ac2a..14fa1b25 100644 --- a/bitswap.go +++ b/bitswap.go @@ -303,7 +303,7 @@ func (bs *Bitswap) LedgerForPeer(p peer.ID) *decision.Receipt { // resources, provide a context with a reasonably short deadline (ie. not one // that lasts throughout the lifetime of the server) func (bs *Bitswap) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) { - session := bs.sm.NewSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay) + session := bs.sm.GetSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay) return session.GetBlocks(ctx, keys) } @@ -525,12 +525,18 @@ func (bs *Bitswap) IsOnline() bool { return true } -// NewSession generates a new Bitswap session. You should use this, rather -// that calling Bitswap.GetBlocks, any time you intend to do several related -// block requests in a row. The session returned will have it's own GetBlocks -// method, but the session will use the fact that the requests are related to -// be more efficient in its requests to peers. If you are using a session -// from go-blockservice, it will create a bitswap session automatically. +// NewSession generates a new Bitswap session or returns the session associated +// with the passed context (created with exchange.NewSession(ctx)). +// +// You should construct a session either with this function or +// exchange.NewSession instead of repeatedly calling Bitswap.GetBlock(s) any +// time you intend to do several related block requests in a row. The session +// will use the fact that the requests are related to be more efficient in its +// requests to peers. +// +// Note: If you've already wrapped your context with exchange.NewSession, you do +// not need to call this function. When you call the GetBlock(s) functions with +// that context, it will use the associated session. func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher { - return bs.sm.NewSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay) + return bs.sm.GetSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay) } diff --git a/internal/sessionmanager/sessionmanager.go b/internal/sessionmanager/sessionmanager.go index 57eda0fd..d63d89ac 100644 --- a/internal/sessionmanager/sessionmanager.go +++ b/internal/sessionmanager/sessionmanager.go @@ -63,29 +63,46 @@ func New(ctx context.Context, sessionFactory SessionFactory, sessionInterestMana } } -// NewSession initializes a session with the given context, and adds to the -// session manager. -func (sm *SessionManager) NewSession(ctx context.Context, +// GetSession gets the session associated with the context, or creates a new +// one. +func (sm *SessionManager) GetSession(ctx context.Context, provSearchDelay time.Duration, rebroadcastDelay delay.D) exchange.Fetcher { - sessionctx, cancel := context.WithCancel(ctx) - // we just need the id - id, _ := exchange.GetOrCreateSession(context.Background()) + id, sessionctx := exchange.GetOrCreateSession(ctx) + + sm.sessLk.RLock() + s, ok := sm.sessions[id] + sm.sessLk.RUnlock() + + if ok { + return s + } - pm := sm.peerManagerFactory(sessionctx, id) - session := sm.sessionFactory(sessionctx, id, pm, sm.sessionInterestManager, sm.peerManager, sm.blockPresenceManager, sm.notif, provSearchDelay, rebroadcastDelay, sm.self) sm.sessLk.Lock() + defer sm.sessLk.Unlock() + + if s, ok := sm.sessions[id]; ok { + return s + } + + sessionctx, cancel := context.WithCancel(sessionctx) + + pm := sm.peerManagerFactory(sessionctx, id) + session := sm.sessionFactory( + sessionctx, id, pm, + sm.sessionInterestManager, sm.peerManager, sm.blockPresenceManager, + sm.notif, provSearchDelay, rebroadcastDelay, sm.self, + ) sm.sessions[id] = session - sm.sessLk.Unlock() + go func() { defer cancel() select { case <-sm.ctx.Done(): - sm.removeSession(id) - case <-ctx.Done(): - sm.removeSession(id) + case <-sessionctx.Done(): } + sm.removeSession(id) }() return session diff --git a/internal/sessionmanager/sessionmanager_test.go b/internal/sessionmanager/sessionmanager_test.go index 033da0ef..df33508a 100644 --- a/internal/sessionmanager/sessionmanager_test.go +++ b/internal/sessionmanager/sessionmanager_test.go @@ -95,9 +95,9 @@ func TestReceiveFrom(t *testing.T) { p := peer.ID(123) block := blocks.NewBlock([]byte("block")) - firstSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) - secondSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) - thirdSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + firstSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + secondSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + thirdSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) sim.RecordSessionInterest(firstSession.ID(), []cid.Cid{block.Cid()}) sim.RecordSessionInterest(thirdSession.ID(), []cid.Cid{block.Cid()}) @@ -138,9 +138,9 @@ func TestReceiveBlocksWhenManagerContextCancelled(t *testing.T) { p := peer.ID(123) block := blocks.NewBlock([]byte("block")) - firstSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) - secondSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) - thirdSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + firstSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + secondSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + thirdSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) sim.RecordSessionInterest(firstSession.ID(), []cid.Cid{block.Cid()}) sim.RecordSessionInterest(secondSession.ID(), []cid.Cid{block.Cid()}) @@ -173,10 +173,10 @@ func TestReceiveBlocksWhenSessionContextCancelled(t *testing.T) { p := peer.ID(123) block := blocks.NewBlock([]byte("block")) - firstSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + firstSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) sessionCtx, sessionCancel := context.WithCancel(ctx) - secondSession := sm.NewSession(sessionCtx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) - thirdSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + secondSession := sm.GetSession(sessionCtx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + thirdSession := sm.GetSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) sim.RecordSessionInterest(firstSession.ID(), []cid.Cid{block.Cid()}) sim.RecordSessionInterest(secondSession.ID(), []cid.Cid{block.Cid()})