diff --git a/pkg/varlogtest/log.go b/pkg/varlogtest/log.go index e2591f4ca..f7cc50b0f 100644 --- a/pkg/varlogtest/log.go +++ b/pkg/varlogtest/log.go @@ -16,10 +16,26 @@ import ( type testLog struct { vt *VarlogTest closed bool + + lsaPool struct { + lsaMap map[int64]*logStreamAppender + nextID int64 + mu sync.Mutex + + wg sync.WaitGroup + } } var _ varlog.Log = (*testLog)(nil) +func newTestLog(vt *VarlogTest) *testLog { + c := &testLog{ + vt: vt, + } + c.lsaPool.lsaMap = make(map[int64]*logStreamAppender) + return c +} + func (c *testLog) lock() error { c.vt.cond.L.Lock() if c.closed { @@ -38,6 +54,14 @@ func (c *testLog) Close() error { defer c.vt.cond.L.Unlock() c.closed = true c.vt.cond.Broadcast() + + c.lsaPool.mu.Lock() + defer c.lsaPool.mu.Unlock() + for _, lsa := range c.lsaPool.lsaMap { + lsa.Close() + } + + c.lsaPool.wg.Wait() return nil } @@ -342,9 +366,22 @@ func (c *testLog) NewLogStreamAppender(tpid types.TopicID, lsid types.LogStreamI lsa.queue.ch = make(chan *queueEntry, pipelineSize) lsa.queue.cv = sync.NewCond(&lsa.queue.mu) - lsa.wg.Add(1) + c.lsaPool.mu.Lock() + defer c.lsaPool.mu.Unlock() + id := c.lsaPool.nextID + c.lsaPool.nextID++ + c.lsaPool.lsaMap[id] = lsa + + c.lsaPool.wg.Add(1) go func() { - defer lsa.wg.Done() + defer func() { + c.lsaPool.wg.Done() + + c.lsaPool.mu.Lock() + defer c.lsaPool.mu.Unlock() + delete(c.lsaPool.lsaMap, id) + }() + for qe := range lsa.queue.ch { qe.callback(qe.result.Metadata, qe.result.Err) lsa.queue.cv.L.Lock() @@ -377,8 +414,6 @@ type logStreamAppender struct { cv *sync.Cond mu sync.Mutex } - - wg sync.WaitGroup } var _ varlog.LogStreamAppender = (*logStreamAppender)(nil) @@ -417,7 +452,6 @@ func (lsa *logStreamAppender) Close() { } lsa.closed.value = true close(lsa.queue.ch) - lsa.wg.Wait() } type errSubscriber struct { diff --git a/pkg/varlogtest/varlogtest.go b/pkg/varlogtest/varlogtest.go index 1fee01a50..f19e88c0b 100644 --- a/pkg/varlogtest/varlogtest.go +++ b/pkg/varlogtest/varlogtest.go @@ -73,7 +73,7 @@ func (vt *VarlogTest) NewAdminClient() varlog.Admin { } func (vt *VarlogTest) NewLogClient() varlog.Log { - return &testLog{vt: vt} + return newTestLog(vt) } func (vt *VarlogTest) generateTopicID() types.TopicID { diff --git a/pkg/varlogtest/varlogtest_test.go b/pkg/varlogtest/varlogtest_test.go index e1bbfd27b..9cb838421 100644 --- a/pkg/varlogtest/varlogtest_test.go +++ b/pkg/varlogtest/varlogtest_test.go @@ -198,6 +198,39 @@ func TestVarlotTest_LogStreamAppender(t *testing.T) { require.NoError(t, err) }, }, + { + name: "CloseInCallback", + testf: func(t *testing.T, vadm varlog.Admin, vcli varlog.Log, tpid types.TopicID, lsid types.LogStreamID) { + lsa, err := vcli.NewLogStreamAppender(tpid, lsid) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + err = lsa.AppendBatch([][]byte{[]byte("foo")}, func(lem []varlogpb.LogEntryMeta, err error) { + defer wg.Done() + assert.NoError(t, err) + lsa.Close() + }) + require.NoError(t, err) + wg.Wait() + }, + }, + { + name: "DoesNotCloseLogStreamAppender", + testf: func(t *testing.T, vadm varlog.Admin, vcli varlog.Log, tpid types.TopicID, lsid types.LogStreamID) { + // Closing the log client will shut down the log stream appender forcefully. + lsa, err := vcli.NewLogStreamAppender(tpid, lsid) + require.NoError(t, err) + + cb := func(_ []varlogpb.LogEntryMeta, err error) { + assert.NoError(t, err) + } + for i := 0; i < numLogs; i++ { + err := lsa.AppendBatch([][]byte{[]byte("foo")}, cb) + require.NoError(t, err) + } + }, + }, { name: "Manager", testf: func(t *testing.T, vadm varlog.Admin, vcli varlog.Log, tpid types.TopicID, lsid types.LogStreamID) {