Skip to content

Commit

Permalink
Refactor checkpoint code
Browse files Browse the repository at this point in the history
  • Loading branch information
aserrallerios committed Dec 19, 2019
1 parent 94ec3e2 commit a681c70
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,29 @@

package akka.stream.alpakka.kinesis

import java.time.Instant

import akka.Done
import akka.annotation.InternalApi
import akka.stream.alpakka.kinesis.impl.ShardProcessor
import akka.stream.alpakka.kinesis.CommittableRecord.{BatchData, ShardProcessorData}
import software.amazon.kinesis.lifecycle.ShutdownReason
import software.amazon.kinesis.retrieval.KinesisClientRecord
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber

import scala.util.Try

final class CommittableRecord @InternalApi private[kinesis] (
val shardId: String,
val recordProcessorStartingSequenceNumber: ExtendedSequenceNumber,
val millisBehindLatest: Long,
abstract class CommittableRecord @InternalApi private[kinesis] (
val record: KinesisClientRecord,
recordProcessor: ShardProcessor,
checkpointer: KinesisClientRecord => Unit
val batchData: BatchData,
val processorData: ShardProcessorData
) {

val sequenceNumber: String = record.sequenceNumber()
val subSequenceNumber: Long = record.subSequenceNumber()

def recordProcessorShutdownReason: Option[ShutdownReason] =
recordProcessor.shutdown

def canBeCheckpointed: Boolean =
!recordProcessorShutdownReason.contains(ShutdownReason.LEASE_LOST)

def tryToCheckpoint(): Try[Done] =
Try(checkpointer(record)).map(_ => Done)
def shutdownReason: Option[ShutdownReason]
def canBeCheckpointed: Boolean
def tryToCheckpoint(): Try[Done]

}

Expand All @@ -43,4 +37,16 @@ object CommittableRecord {
// same sequence number but will differ by subsequence number
implicit val orderBySequenceNumber: Ordering[CommittableRecord] =
Ordering[(String, Long)].on(cr (cr.sequenceNumber, cr.subSequenceNumber))

final class ShardProcessorData(
val shardId: String,
val recordProcessorStartingSequenceNumber: ExtendedSequenceNumber,
val pendingCheckpointSequenceNumber: ExtendedSequenceNumber
)
final class BatchData(
val cacheEntryTime: Instant,
val cacheExitTime: Instant,
val isAtShardEnd: Boolean,
val millisBehindLatest: Long
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private[kinesis] class KinesisSchedulerSourceStage(
inheritedAttributes: Attributes
): (GraphStageLogic, Future[Scheduler]) = {
val matValue = Promise[Scheduler]()
(new Logic(matValue), matValue.future)
new Logic(matValue) -> matValue.future
}

class Logic(matValue: Promise[Scheduler]) extends GraphStageLogic(shape) with StageLogging with OutHandler {
Expand All @@ -68,12 +68,13 @@ private[kinesis] class KinesisSchedulerSourceStage(

override def preStart(): Unit = {
self = getStageActor(awaitingRecords)
val newRecordCallback: CommittableRecord => Unit = {
semaphore.tryAcquire(backpressureTimeout.length, backpressureTimeout.unit)
self.ref ! NewRecord(_)
}
scheduler = schedulerBuilder(new ShardRecordProcessorFactory {
override def shardRecordProcessor(): ShardRecordProcessor =
new ShardProcessor(record => {
semaphore.tryAcquire(backpressureTimeout.length, backpressureTimeout.unit)
self.ref ! NewRecord(record)
})
new ShardProcessor(newRecordCallback)
})
Future(scheduler.run()).onComplete(self.ref ! SchedulerShutdown(_))
matValue.success(scheduler)
Expand All @@ -97,6 +98,6 @@ private[kinesis] class KinesisSchedulerSourceStage(
}
override def onPull(): Unit = self.ref ! Pump
override def onDownstreamFinish(): Unit = self.ref ! Complete
override def postStop(): Unit = Future(scheduler.shutdown())
override def postStop(): Unit = Future(if (!scheduler.shutdownComplete()) scheduler.shutdown())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,84 +6,91 @@ package akka.stream.alpakka.kinesis.impl

import java.util.concurrent.Semaphore

import akka.Done
import akka.annotation.InternalApi
import akka.stream.alpakka.kinesis.CommittableRecord
import akka.stream.alpakka.kinesis.CommittableRecord.{BatchData, ShardProcessorData}
import software.amazon.kinesis.lifecycle.ShutdownReason
import software.amazon.kinesis.lifecycle.events._
import software.amazon.kinesis.processor.ShardRecordProcessor
import software.amazon.kinesis.processor.{RecordProcessorCheckpointer, ShardRecordProcessor}
import software.amazon.kinesis.retrieval.KinesisClientRecord
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber

import scala.collection.JavaConverters._
import scala.util.Try

@InternalApi
private[kinesis] class ShardProcessor(
callback: CommittableRecord => Unit
) extends ShardRecordProcessor {

private var shardId: String = _
private var extendedSequenceNumber: ExtendedSequenceNumber = _

private val semaphore = new Semaphore(1)

var shutdown: Option[ShutdownReason] = None
private var shardData: ShardProcessorData = _
private var checkpointer: RecordProcessorCheckpointer = _
private var shutdown: Option[ShutdownReason] = None

override def initialize(initializationInput: InitializationInput): Unit = {
shardId = initializationInput.shardId()
extendedSequenceNumber = initializationInput.extendedSequenceNumber()
}
override def initialize(initializationInput: InitializationInput): Unit =
shardData = new ShardProcessorData(initializationInput.shardId,
initializationInput.extendedSequenceNumber,
initializationInput.pendingCheckpointSequenceNumber)

override def processRecords(processRecordsInput: ProcessRecordsInput): Unit = {
if (processRecordsInput.isAtShardEnd) {
checkpointer = processRecordsInput.checkpointer()

val batchData = new BatchData(processRecordsInput.cacheEntryTime,
processRecordsInput.cacheExitTime,
processRecordsInput.isAtShardEnd,
processRecordsInput.millisBehindLatest)

if (batchData.isAtShardEnd) {
semaphore.acquire()
}

// This implementation will try to checkpoint every Record with the original
// checkpointer. Other option would be to keep a reference of the latest
// checkpointer passed to this instance using any of these methods:
// * processRecords
// * shutdownRequested
// * shardEnded
val checkpoint = (record: KinesisClientRecord) =>
processRecordsInput.checkpointer().checkpoint(record.sequenceNumber(), record.subSequenceNumber())
val checkpointAndRelease = checkpoint andThen (_ => semaphore.release())
val numberOfRecords = processRecordsInput.records().size()

processRecordsInput.records().asScala.zipWithIndex.foreach {
case (record, index) =>
callback(
new CommittableRecord(
shardId,
extendedSequenceNumber,
processRecordsInput.millisBehindLatest(),
new InternalCommittableRecord(
record,
recordProcessor = this,
if (processRecordsInput.isAtShardEnd && index + 1 == numberOfRecords) {
checkpointAndRelease
} else {
checkpoint
}
batchData,
isLatestRecord = processRecordsInput.isAtShardEnd && index + 1 == numberOfRecords
)
)
}
}

final class InternalCommittableRecord(record: KinesisClientRecord, batchData: BatchData, isLatestRecord: Boolean)
extends CommittableRecord(record, batchData, shardData) {
private def checkpoint(): Unit =
checkpointer.checkpoint(record.sequenceNumber(), record.subSequenceNumber())
private def checkpointAndRelease(): Unit = { checkpoint(); semaphore.release() }

override def shutdownReason: Option[ShutdownReason] = shutdown
override def canBeCheckpointed: Boolean =
!shutdownReason.contains(ShutdownReason.LEASE_LOST)
override def tryToCheckpoint(): Try[Done] =
Try(if (isLatestRecord) checkpointAndRelease() else checkpoint()).map(_ => Done)
}

override def leaseLost(leaseLostInput: LeaseLostInput): Unit =
// We cannot checkpoint at this point as we don't have the
// lease anymore
shutdown = Some(ShutdownReason.LEASE_LOST)

override def shardEnded(shardEndedInput: ShardEndedInput): Unit = {
checkpointer = shardEndedInput.checkpointer()
// We must checkpoint to finish the shard, but we wait
// until all records in flight have been processed
shutdown = Some(ShutdownReason.SHARD_END)
semaphore.acquire()
shardEndedInput.checkpointer.checkpoint()
checkpointer.checkpoint()
}

override def shutdownRequested(shutdownInput: ShutdownRequestedInput): Unit =
override def shutdownRequested(shutdownInput: ShutdownRequestedInput): Unit = {
checkpointer = shutdownInput.checkpointer()
// We don't checkpoint at this point as we assume the
// standard mechanism will checkpoint when required
shutdown = Some(ShutdownReason.REQUESTED)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

package akka.stream.alpakka.kinesis.javadsl

import java.util.concurrent.{CompletionStage, Executor}
import java.util.concurrent.CompletionStage

import akka.NotUsed
import akka.stream.alpakka.kinesis.{CommittableRecord, scaladsl, _}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ object KinesisSchedulerSource {
Future[Scheduler]
]] =
apply(schedulerBuilder, settings)
.groupBy(MAX_KINESIS_SHARDS, _.shardId)
.groupBy(MAX_KINESIS_SHARDS, _.processorData.shardId)

def checkpointRecordsFlow(
settings: KinesisSchedulerCheckpointSettings = KinesisSchedulerCheckpointSettings.defaultInstance
): Flow[CommittableRecord, KinesisClientRecord, NotUsed] =
Flow[CommittableRecord]
.groupBy(MAX_KINESIS_SHARDS, _.shardId)
.groupBy(MAX_KINESIS_SHARDS, _.processorData.shardId)
.groupedWithin(settings.maxBatchSize, settings.maxBatchWait)
.via(GraphDSL.create() { implicit b =>
import GraphDSL.Implicits._
Expand All @@ -76,12 +76,11 @@ object KinesisSchedulerSource {
case record if record.canBeCheckpointed =>
record
.tryToCheckpoint()
.recover {
.recover({
case _: ShutdownException => Done
}
})
.get
case _ =>
Done
case _ => Done
}
.addAttributes(Attributes(ActorAttributes.IODispatcher))
) ~> join.in0
Expand Down
Loading

0 comments on commit a681c70

Please sign in to comment.