[SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka

This PR allows users to map a Kinesis `Record` to a generic type `T` when creating a Kinesis stream. This is particularly useful if you would like to do extra work with Kinesis metadata such as the sequence number and partition key.
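For example, a handler can keep the partition key and sequence number next to the payload bytes. A minimal sketch (the application name, stream name, endpoint, and intervals are placeholders, and the parameter list ahead of `messageHandler` is assumed to mirror the existing byte-array `createStream` overloads):

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object KinesisMetadataExample {
  // Keep the partition key and sequence number alongside the payload bytes.
  // The handler is shipped to the executors, so it must be serializable.
  def recordToTuple(record: Record): (String, String, Array[Byte]) = {
    val buffer = record.getData()
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    (record.getPartitionKey, record.getSequenceNumber, bytes)
  }

  def main(args: Array[String]): Unit = {
    // local[2]: one thread for the receiver, one for processing.
    val conf = new SparkConf().setAppName("KinesisHandlerDemo").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(2))

    // The trailing messageHandler argument is what this PR adds; the resulting
    // stream is typed by the handler's return type instead of Array[Byte].
    val stream = KinesisUtils.createStream(
      ssc, "myKinesisApp", "myStreamName",
      "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
      InitialPositionInStream.LATEST, Seconds(10),
      StorageLevel.MEMORY_AND_DISK_2,
      recordToTuple _)

    stream.map { case (partitionKey, seqNumber, _) => (partitionKey, seqNumber) }.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```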

TODO:
 - [x] add tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #8954 from brkyvz/kinesis-handler.
brkyvz authored and tdas committed Oct 26, 2015
1 parent 80279ac commit 63accc7
Showing 9 changed files with 337 additions and 75 deletions (only four of those files appear below).
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.kinesis
 
 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
@@ -67,16 +68,17 @@ class KinesisBackedBlockRDDPartition(
  * sequence numbers of the corresponding blocks.
  */
 private[kinesis]
-class KinesisBackedBlockRDD(
+class KinesisBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
     val regionName: String,
     val endpointUrl: String,
     @transient blockIds: Array[BlockId],
     @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
     @transient isBlockIdValid: Array[Boolean] = Array.empty,
     val retryTimeoutMs: Int = 10000,
+    val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
     val awsCredentialsOption: Option[SerializableAWSCredentials] = None
-  ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+  ) extends BlockRDD[T](sc, blockIds) {
 
   require(blockIds.length == arrayOfseqNumberRanges.length,
     "Number of blockIds is not equal to the number of sequence number ranges")
@@ -90,23 +92,23 @@
     }
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
     val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
     val blockId = partition.blockId
 
-    def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+    def getBlockFromBlockManager(): Option[Iterator[T]] = {
       logDebug(s"Read partition data of $this from block manager, block $blockId")
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
-    def getBlockFromKinesis(): Iterator[Array[Byte]] = {
-      val credenentials = awsCredentialsOption.getOrElse {
+    def getBlockFromKinesis(): Iterator[T] = {
+      val credentials = awsCredentialsOption.getOrElse {
        new DefaultAWSCredentialsProviderChain().getCredentials()
      }
      partition.seqNumberRanges.ranges.iterator.flatMap { range =>
-        new KinesisSequenceRangeIterator(
-          credenentials, endpointUrl, regionName, range, retryTimeoutMs)
+        new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
+          range, retryTimeoutMs).map(messageHandler)
       }
     }
     if (partition.isBlockIdValid) {
@@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator(
     endpointUrl: String,
     regionId: String,
     range: SequenceNumberRange,
-    retryTimeoutMs: Int
-  ) extends NextIterator[Array[Byte]] with Logging {
+    retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
 
   private val client = new AmazonKinesisClient(credentials)
   private val streamName = range.streamName
@@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator(
 
   client.setEndpoint(endpointUrl, "kinesis", regionId)
 
-  override protected def getNext(): Array[Byte] = {
-    var nextBytes: Array[Byte] = null
+  override protected def getNext(): Record = {
+    var nextRecord: Record = null
     if (toSeqNumberReceived) {
       finished = true
     } else {
@@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator(
       } else {
 
         // Get the record, copy the data into a byte array and remember its sequence number
-        val nextRecord: Record = internalIterator.next()
-        val byteBuffer = nextRecord.getData()
-        nextBytes = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(nextBytes)
+        nextRecord = internalIterator.next()
         lastSeqNumber = nextRecord.getSequenceNumber()
 
         // If the this record's sequence number matches the stopping sequence number, then make sure
@@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator(
           toSeqNumberReceived = true
        }
      }
-
    }
-    nextBytes
+    nextRecord
  }
 
  override protected def close(): Unit = {
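The byte-copying removed from `KinesisSequenceRangeIterator` above does not disappear: the new default handler, `KinesisUtils.defaultMessageHandler` (defined in one of the changed files that is not loaded on this page), presumably performs the same copy so that the default element type stays `Array[Byte]` and existing callers see no behavior change. A sketch of that assumed shape:

```scala
import com.amazonaws.services.kinesis.model.Record

object DefaultMessageHandlerSketch {
  // Assumed shape of KinesisUtils.defaultMessageHandler: copy the record payload
  // into a byte array, mirroring the logic removed from the iterator above.
  def defaultMessageHandler(record: Record): Array[Byte] = {
    if (record == null) return null
    val byteBuffer = record.getData()
    val byteArray = new Array[Byte](byteBuffer.remaining())
    byteBuffer.get(byteArray)
    byteArray
  }
}
```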
@@ -17,7 +17,10 @@
 
 package org.apache.spark.streaming.kinesis
 
+import scala.reflect.ClassTag
+
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
 import org.apache.spark.streaming.{Duration, StreamingContext, Time}
 
-private[kinesis] class KinesisInputDStream(
+private[kinesis] class KinesisInputDStream[T: ClassTag](
     @transient _ssc: StreamingContext,
     streamName: String,
     endpointUrl: String,
@@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream(
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
+    messageHandler: Record => T,
     awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends ReceiverInputDStream[Array[Byte]](_ssc) {
+  ) extends ReceiverInputDStream[T](_ssc) {
 
   private[streaming]
-  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
+  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = {
 
     // This returns true even for when blockInfos is empty
     val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
@@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream(
         context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
         isBlockIdValid = isBlockIdValid,
         retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
+        messageHandler = messageHandler,
         awsCredentialsOption = awsCredentialsOption)
     } else {
       logWarning("Kinesis sequence number information was not present with some block metadata," +
@@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream(
     }
   }
 
-  override def getReceiver(): Receiver[Array[Byte]] = {
+  override def getReceiver(): Receiver[T] = {
     new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
-      checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
+      checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
   }
 }
@@ -80,16 +80,17 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
  * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
  *        the credentials
  */
-private[kinesis] class KinesisReceiver(
+private[kinesis] class KinesisReceiver[T](
     val streamName: String,
     endpointUrl: String,
     regionName: String,
     initialPositionInStream: InitialPositionInStream,
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
-    awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver =>
+    messageHandler: Record => T,
+    awsCredentialsOption: Option[SerializableAWSCredentials])
+  extends Receiver[T](storageLevel) with Logging { receiver =>
 
   /*
    * =================================================================================
@@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver(
   /** Add records of the given shard to the current block being generated */
   private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
     if (records.size > 0) {
-      val dataIterator = records.iterator().asScala.map { record =>
-        val byteBuffer = record.getData()
-        val byteArray = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(byteArray)
-        byteArray
-      }
+      val dataIterator = records.iterator().asScala.map(messageHandler)
       val metadata = SequenceNumberRange(streamName, shardId,
         records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
       blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
@@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver(
 
   /** Store the block along with its associated ranges */
   private def storeBlockWithRanges(
-      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
+      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = {
     val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
     if (rangesToReportOption.isEmpty) {
       stop("Error while storing block into Spark, could not find sequence number ranges " +
@@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver(
     /** Callback method called when a block is ready to be pushed / stored. */
     def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
       storeBlockWithRanges(blockId,
-        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
+        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]])
     }
 
     /** Callback called in case of any error in internal of the BlockGenerator */
@@ -41,8 +41,8 @@ import org.apache.spark.Logging
  * @param checkpointState represents the checkpoint state including the next checkpoint time.
  *        It's injected here for mocking purposes.
  */
-private[kinesis] class KinesisRecordProcessor(
-    receiver: KinesisReceiver,
+private[kinesis] class KinesisRecordProcessor[T](
+    receiver: KinesisReceiver[T],
     workerId: String,
     checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
 