diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala index cd3024634e2b2..30b6b6a424969 100755 --- a/core/src/main/scala/kafka/log/LogManager.scala +++ b/core/src/main/scala/kafka/log/LogManager.scala @@ -35,6 +35,7 @@ import scala.jdk.CollectionConverters._ import scala.collection._ import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success, Try} +import kafka.utils.Implicits._ /** * The entry point to the kafka log management subsystem. The log manager is responsible for log creation, retrieval, and cleaning. @@ -478,17 +479,9 @@ class LogManager(logDirs: Seq[File], } try { - for ((dir, dirJobs) <- jobs) { - val hasErrors = dirJobs.exists { future => - Try(future.get) match { - case Success(_) => false - case Failure(e) => - warn(s"There was an error in one of the threads during LogManager shutdown: ${e.getCause}") - true - } - } - - if (!hasErrors) { + jobs.forKeyValue { (dir, dirJobs) => + if (waitForAllToComplete(dirJobs, + e => warn(s"There was an error in one of the threads during LogManager shutdown: ${e.getCause}"))) { val logs = logsInDir(localLogsByDir, dir) // update the last flush point @@ -1167,6 +1160,21 @@ class LogManager(logDirs: Seq[File], object LogManager { + /** + * Wait all jobs to complete + * @param jobs jobs + * @param callback this will be called to handle the exception caused by each Future#get + * @return true if all pass. Otherwise, false + */ + private[log] def waitForAllToComplete(jobs: Seq[Future[_]], callback: Throwable => Unit): Boolean = { + jobs.count(future => Try(future.get) match { + case Success(_) => false + case Failure(e) => + callback(e) + true + }) == 0 + } + val RecoveryPointCheckpointFile = "recovery-point-offset-checkpoint" val LogStartOffsetCheckpointFile = "log-start-offset-checkpoint" val ProducerIdExpirationCheckIntervalMs = 10 * 60 * 1000 diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala index 031000db75522..2970c91bb6fc6 100755 --- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala @@ -17,14 +17,10 @@ package kafka.log -import java.io._ -import java.nio.file.Files -import java.util.{Collections, Properties} - import com.yammer.metrics.core.MetricName import kafka.metrics.KafkaYammerMetrics -import kafka.server.{FetchDataInfo, FetchLogEnd} import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.server.{FetchDataInfo, FetchLogEnd} import kafka.utils._ import org.apache.directory.api.util.FileUtils import org.apache.kafka.common.errors.OffsetOutOfRangeException @@ -34,8 +30,13 @@ import org.easymock.EasyMock import org.junit.Assert._ import org.junit.{After, Before, Test} import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito import org.mockito.Mockito.{doAnswer, spy} +import java.io._ +import java.nio.file.Files +import java.util.concurrent.Future +import java.util.{Collections, Properties} import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.{Failure, Try} @@ -680,4 +681,34 @@ class LogManagerTest { time.sleep(logConfig.fileDeleteDelayMs + 1) verifyMetrics(1) } + + @Test + def testWaitForAllToComplete(): Unit = { + var invokedCount = 0 + val success: Future[Boolean] = Mockito.mock(classOf[Future[Boolean]]) + Mockito.when(success.get()).thenAnswer { _ => + invokedCount += 1 + true + } + val failure: Future[Boolean] = Mockito.mock(classOf[Future[Boolean]]) + Mockito.when(failure.get()).thenAnswer{ _ => + invokedCount += 1 + throw new RuntimeException + } + + var failureCount = 0 + // all futures should be evaluated + assertFalse(LogManager.waitForAllToComplete(Seq(success, failure), _ => failureCount += 1)) + assertEquals(2, invokedCount) + assertEquals(1, failureCount) + assertFalse(LogManager.waitForAllToComplete(Seq(failure, success), _ => failureCount += 1)) + assertEquals(4, invokedCount) + assertEquals(2, failureCount) + assertTrue(LogManager.waitForAllToComplete(Seq(success, success), _ => failureCount += 1)) + assertEquals(6, invokedCount) + assertEquals(2, failureCount) + assertFalse(LogManager.waitForAllToComplete(Seq(failure, failure), _ => failureCount += 1)) + assertEquals(8, invokedCount) + assertEquals(4, failureCount) + } }