[WIP] Skip features #1017

Open · wants to merge 2 commits into main
@@ -219,7 +219,7 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession,
     val anchorDFMap1 = anchorToDataSourceMapper.getBasicAnchorDFMapForJoin(ss, Seq(featureAnchor), failOnMissingPartition)
     val featureInfo = FeatureTransformation.directCalculate(
       anchorGroup: AnchorFeatureGroups,
-      anchorDFMap1(featureAnchor),
+      anchorDFMap1(featureAnchor).get,
       featureAnchor.featureAnchor.sourceKeyExtractor,
       None,
       None,
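Note: `getBasicAnchorDFMapForJoin` now returns `Option[DataSourceAccessor]` values (see the AnchorToDataSourceMapper hunks below), so this call site unwraps with `.get`. A minimal sketch of what that implies, using illustrative stand-in types rather than Feathr's:

```scala
// Stand-in types only: the anchor-to-accessor map now holds Option values,
// so a bare .get throws NoSuchElementException when the source was skipped.
case class Accessor(path: String)

val anchorDFMap: Map[String, Option[Accessor]] =
  Map("f1" -> Some(Accessor("daily/2022/05/01")), "f2" -> None)

val f1Accessor = anchorDFMap("f1").get                              // present: fine
// anchorDFMap("f2").get                                            // would throw: source was skipped
val f2OrFallback = anchorDFMap("f2").getOrElse(Accessor("fallback"))
```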
@@ -70,7 +70,7 @@ object DataSourceNodeEvaluator extends NodeEvaluator{
       val timeStampExpr = constructTimeStampExpr(timeWindowParam.timestampColumn, timeWindowParam.timestampColumnFormat)
       val needTimestampColumn = if (dataSourceNode.hasTimestampColumnInfo) false else true
       val dataSourceAccessor = DataSourceAccessor(ss, source, timeRange, None, failOnMissingPartition = false, needTimestampColumn, dataPathHandlers = dataPathHandlers)
-      val sourceDF = dataSourceAccessor.get()
+      val sourceDF = dataSourceAccessor.get.get()
       val (df, keyExtractor, timestampExpr) = if (dataSourceNode.getKeyExpressionType == KeyExpressionType.UDF) {
         val className = Class.forName(dataSourceNode.getKeyExpression())
         val keyExtractorClass = className.newInstance match {
@@ -110,7 +110,7 @@ object DataSourceNodeEvaluator extends NodeEvaluator{
       // Augment time information also here. Table node should not have time info?
       val dataSource = com.linkedin.feathr.offline.source.DataSource(path, SourceFormatType.FIXED_PATH)
       val dataSourceAccessor = DataSourceAccessor(ss, dataSource, None, None, failOnMissingPartition = false, dataPathHandlers = dataPathHandlers)
-      val sourceDF = dataSourceAccessor.get()
+      val sourceDF = dataSourceAccessor.get.get()
       val (df, keyExtractor) = if (dataSourceNode.getKeyExpressionType == KeyExpressionType.UDF) {
         val className = Class.forName(dataSourceNode.getKeyExpression())
         className.newInstance match {
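The `.get.get()` reads oddly at first glance: the outer `.get` unwraps the `Option[DataSourceAccessor]` that the factory now returns, while the inner `.get()` is the accessor's own method that materializes the DataFrame. A small sketch with stand-in types (not Feathr's):

```scala
// Stand-in types: why two gets appear back to back after this change.
trait Accessor { def get(): String }   // pretend get() loads a DataFrame
val accessorOpt: Option[Accessor] = Some(new Accessor { def get(): String = "df" })

val df = accessorOpt.get.get()   // Option#get unwraps, then Accessor#get() loads;
                                 // throws NoSuchElementException if the source was skipped
```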
@@ -238,7 +238,7 @@ object FeatureGenJob {
     // For example, f1, f2 belongs to anchor. Then Map("f1,f2"-> anchor)
     val dataFrameMapForPreprocessing = anchorsWithSource
       .filter(x => featureNamesInAnchorSet.contains(x._1.featureAnchor.features.toSeq.sorted.mkString(",")))
-      .map(x => (x._1.featureAnchor.features.toSeq.sorted.mkString(","), x._2.get()))
+      .map(x => (x._1.featureAnchor.features.toSeq.sorted.mkString(","), x._2.get.get()))

     // Pyspark only understand Java map so we need to convert Scala map back to Java map.
     dataFrameMapForPreprocessing.asJava
@@ -384,7 +384,7 @@ object FeatureJoinJob {
     // For example, f1, f2 belongs to anchor. Then Map("f1,f2"-> anchor)
     val dataFrameMapForPreprocessing = anchorsWithSource
       .filter(x => featureNamesInAnchorSet.contains(x._1.featureAnchor.features.toSeq.sorted.mkString(",")))
-      .map(x => (x._1.featureAnchor.features.toSeq.sorted.mkString(","), x._2.get()))
+      .map(x => (x._1.featureAnchor.features.toSeq.sorted.mkString(","), x._2.get.get()))

     // Pyspark only understand Java map so we need to convert Scala map back to Java map.
     dataFrameMapForPreprocessing.asJava
@@ -190,6 +190,8 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
         .map(featureGroups.allAnchoredFeatures),
       failOnMissingPartition)

+    val updatedAnchorSourceAccessorMap = anchorSourceAccessorMap.filter(x => x._2.isDefined).map(x => x._1 -> x._2.get)
+
     implicit val joinExecutionContext: JoinExecutionContext =
       JoinExecutionContext(ss, logicalPlan, featureGroups, bloomFilters, Some(saltedJoinFrequentItemDFs))
     // 3. Join sliding window aggregation features
@@ -210,7 +212,7 @@
         SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), mvelContext)
     }
     val FeatureDataFrameOutput(FeatureDataFrame(withAllBasicAnchoredFeatureDF, inferredBasicAnchoredFeatureTypes)) =
-      anchoredFeatureJoinStep.joinFeatures(requiredRegularFeatureAnchors, AnchorJoinStepInput(withWindowAggFeatureDF, anchorSourceAccessorMap))
+      anchoredFeatureJoinStep.joinFeatures(requiredRegularFeatureAnchors, AnchorJoinStepInput(withWindowAggFeatureDF, updatedAnchorSourceAccessorMap))
     // 5. If useSlickJoin, restore(join back) all observation fields before we evaluate post derived features, sequential join and passthrough
     // anchored features, as they might require other columns in the original observation data, while the current observation
     // dataset does not have these fields (were removed in the preProcessObservation)
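`updatedAnchorSourceAccessorMap` drops the entries whose accessor is `None` before the anchored-feature join step runs. The `filter`/`map` pair used here and a single-pass `collect` are equivalent; illustrative stand-in types:

```scala
// Stand-in types: dropping None entries from an Option-valued map.
case class Accessor(path: String)

val anchorSourceAccessorMap: Map[String, Option[Accessor]] =
  Map("a1" -> Some(Accessor("p1")), "a2" -> None)

// As written in this PR:
val updated1 = anchorSourceAccessorMap.filter(x => x._2.isDefined).map(x => x._1 -> x._2.get)
// Equivalent single pass:
val updated2 = anchorSourceAccessorMap.collect { case (k, Some(v)) => k -> v }
// Both yield Map("a1" -> Accessor("p1"))
```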
@@ -4,7 +4,7 @@ import com.linkedin.feathr.offline.source.dataloader.DataLoaderFactory
 import com.linkedin.feathr.offline.source.pathutil.{PathChecker, TimeBasedHdfsPathAnalyzer}
 import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler
 import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
-import com.linkedin.feathr.offline.util.PartitionLimiter
+import com.linkedin.feathr.offline.util.{FeathrUtils, PartitionLimiter}
 import com.linkedin.feathr.offline.util.datetime.DateTimeInterval
 import org.apache.spark.sql.{DataFrame, SparkSession}

@@ -48,21 +48,22 @@ private[offline] object DataSourceAccessor {
       failOnMissingPartition: Boolean,
       addTimestampColumn: Boolean = false,
       isStreaming: Boolean = false,
-      dataPathHandlers: List[DataPathHandler]): DataSourceAccessor = { //TODO: Add tests
+      dataPathHandlers: List[DataPathHandler]): Option[DataSourceAccessor] = { //TODO: Add tests

     val dataAccessorHandlers: List[DataAccessorHandler] = dataPathHandlers.map(_.dataAccessorHandler)
     val dataLoaderHandlers: List[DataLoaderHandler] = dataPathHandlers.map(_.dataLoaderHandler)

     val sourceType = source.sourceType
     val dataLoaderFactory = DataLoaderFactory(ss, isStreaming, dataLoaderHandlers)
+    val skipFeature = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.SKIP_MISSING_FEATURE).toBoolean
     if (isStreaming) {
-      new StreamDataSourceAccessor(ss, dataLoaderFactory, source)
+      Some(new StreamDataSourceAccessor(ss, dataLoaderFactory, source))
     } else if (dateIntervalOpt.isEmpty || sourceType == SourceFormatType.FIXED_PATH || sourceType == SourceFormatType.LIST_PATH) {
       // if no input interval, or the path is fixed or list, load whole dataset
-      new NonTimeBasedDataSourceAccessor(ss, dataLoaderFactory, source, expectDatumType)
+      Some(new NonTimeBasedDataSourceAccessor(ss, dataLoaderFactory, source, expectDatumType))
     } else {
       import scala.util.control.Breaks._

       val timeInterval = dateIntervalOpt.get
       var dataAccessorOpt: Option[DataSourceAccessor] = None
       breakable {
@@ -74,8 +74,12 @@
         }
       }
       val dataAccessor = dataAccessorOpt match {
-        case Some(dataAccessor) => dataAccessor
-        case _ => createFromHdfsPath(ss, source, timeInterval, expectDatumType, failOnMissingPartition, addTimestampColumn, dataLoaderHandlers)
+        case Some(dataAccessor) => dataAccessorOpt
+        case _ => try {
+          Some(createFromHdfsPath(ss, source, timeInterval, expectDatumType, failOnMissingPartition, addTimestampColumn, dataLoaderHandlers))
+        } catch {
+          case e: Exception => if (!skipFeature) throw e else None
+        }
       }
       dataAccessor
     }
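This hunk is the core of the feature-skipping behavior: when `skip.missing.feature` is enabled, a failure to build an HDFS accessor becomes `None` instead of an exception. A self-contained sketch of the same pattern (the helper name is a stand-in, not Feathr API):

```scala
// Stand-in helper: convert a failing load into None when skipping is enabled,
// otherwise rethrow the original exception.
def loadOrSkip[A](skipMissing: Boolean)(load: => A): Option[A] =
  try Some(load)
  catch { case e: Exception => if (!skipMissing) throw e else None }

val skipped = loadOrSkip(skipMissing = true) {
  throw new RuntimeException("no partitions in range")
}                                           // skipped == None, job continues
// loadOrSkip(skipMissing = false) { ... }  // would propagate the exception
```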
@@ -106,6 +111,7 @@
     val partitionLimiter = new PartitionLimiter(ss)
     val pathAnalyzer = new TimeBasedHdfsPathAnalyzer(pathChecker, dataLoaderHandlers)
     val fileName = new File(source.path).getName
+    val skipFeature = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.SKIP_MISSING_FEATURE).toBoolean
     if (source.timePartitionPattern.isDefined) {
       // case 1: the timePartitionPattern exists
       val pathInfo = pathAnalyzer.analyze(source.path, source.timePartitionPattern.get)
@@ -117,7 +123,8 @@
         source,
         timeInterval,
         failOnMissingPartition,
-        addTimestampColumn)
+        addTimestampColumn,
+        skipFeature)
     } else {
       // legacy configurations without timePartitionPattern
       if (fileName.endsWith("daily") || fileName.endsWith("hourly") || source.sourceType == SourceFormatType.TIME_PATH) {
@@ -131,7 +138,8 @@
           source,
           timeInterval,
           failOnMissingPartition,
-          addTimestampColumn)
+          addTimestampColumn,
+          skipFeature)
       } else {
         // case 3: load as whole dataset
         new NonTimeBasedDataSourceAccessor(ss, fileLoaderFactory, source, expectDatumType)
@@ -147,7 +155,7 @@
   */
 private[offline] case class DataAccessorHandler(
     validatePath: String => Boolean,
-    getAccessor:
+    getAccessor:
     (
       SparkSession,
       DataSource,
@@ -7,7 +7,7 @@ import com.linkedin.feathr.offline.source.dataloader.DataLoaderFactory
 import com.linkedin.feathr.offline.source.pathutil.{PathChecker, PathInfo, TimeBasedHdfsPathGenerator}
 import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils
 import com.linkedin.feathr.offline.transformation.DataFrameExt._
-import com.linkedin.feathr.offline.util.PartitionLimiter
+import com.linkedin.feathr.offline.util.{FeathrUtils, PartitionLimiter}
 import com.linkedin.feathr.offline.util.datetime.{DateTimeInterval, OfflineDateTimeUtils}
 import org.apache.log4j.Logger
 import org.apache.spark.sql.DataFrame
@@ -122,6 +122,7 @@ private[offline] object PathPartitionedTimeSeriesSourceAccessor {
   * @param timeInterval timespan of dataset
   * @param failOnMissingPartition whether to fail the file loading if some of the date partitions are missing.
   * @param addTimestampColumn whether to create a timestamp column from the time partition of the source.
+  * @param skipFeature whether to skip this feature (instead of failing) when its source data is not present.
   * @return a TimeSeriesSource
   */
  def apply(
@@ -132,23 +133,25 @@
      source: DataSource,
      timeInterval: DateTimeInterval,
      failOnMissingPartition: Boolean,
-     addTimestampColumn: Boolean): DataSourceAccessor = {
+     addTimestampColumn: Boolean,
+     skipFeature: Boolean): DataSourceAccessor = {
    val pathGenerator = new TimeBasedHdfsPathGenerator(pathChecker)
    val dateTimeResolution = pathInfo.dateTimeResolution
    val postPath = source.postPath
    val postfixPath = if(postPath.isEmpty || postPath.startsWith("/")) postPath else "/" + postPath
    val pathList = pathGenerator.generate(pathInfo, timeInterval, !failOnMissingPartition, postfixPath)
    val timeFormatString = pathInfo.datePathPattern

    val dataframes = pathList.map(path => {
      val timeStr = path.substring(path.length - (timeFormatString.length + postfixPath.length), path.length - postfixPath.length)
      val time = OfflineDateTimeUtils.createTimeFromString(timeStr, timeFormatString)
      val interval = DateTimeInterval.createFromInclusive(time, time, dateTimeResolution)

      val df = fileLoaderFactory.create(path).loadDataFrame()
      (df, interval)
    })
-
-    if (dataframes.isEmpty) {
+
+    if (dataframes.isEmpty && !skipFeature) {
      val errMsg = s"Input data is empty for creating TimeSeriesSource. No available " +
        s"date partition exist in HDFS for path ${pathInfo.basePath} between ${timeInterval.getStart} and ${timeInterval.getEnd} "
      val errMsgPf = errMsg + s"with postfix path ${postfixPath}"
@@ -160,6 +163,7 @@
          ErrorLabel.FEATHR_USER_ERROR, errMsgPf)
      }
    }
+
    val datePartitions = dataframes.map {
      case (df, interval) =>
        DatePartition(df, interval)
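A worked example of the substring arithmetic above that recovers the date portion of each partition path; the concrete path, pattern, and postfix are illustrative values:

```scala
// Illustrative values: extract the trailing date segment of a partition path,
// given the date-pattern length and an optional postfix path.
val path = "/data/features/daily/2022/05/01"
val timeFormatString = "yyyy/MM/dd"   // length 10
val postfixPath = ""                  // no postfix in this example

val timeStr = path.substring(
  path.length - (timeFormatString.length + postfixPath.length),
  path.length - postfixPath.length)
// timeStr == "2022/05/01"
```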
@@ -115,6 +115,8 @@ private[offline] class SlidingWindowAggregationJoiner(
         case (source, grouped) => (source, grouped.map(_._2))
       })

+    val notJoinedFeatures = new mutable.HashSet[String]()
+
     // For each source, we calculate the maximum window duration that needs to be loaded across all
     // required SWA features defined on this source.
     // Then we load the source only once.
@@ -140,10 +142,13 @@
           maxDurationPerSource,
           featuresToDelayImmutableMap.values.toArray,
           failOnMissingPartition)
+
+        if (originalSourceDf.isEmpty) {
+          res.map(notJoinedFeatures.add)
+          anchors.map(anchor => (anchor, originalSourceDf))
+        } else {
         val sourceDF: DataFrame = preprocessedDf match {
           case Some(existDf) => existDf
-          case None => originalSourceDf
+          case None => originalSourceDf.get
         }

@@ -155,13 +160,17 @@
         // all the anchors here have same key sourcekey extractor, so we just use the first one to generate key column and share
           case keyExtractor => keyExtractor.appendKeyColumns(sourceDF)
         }

-        anchors.map(anchor => (anchor, withKeyDF))
-      })
+        anchors.map(anchor => (anchor, Some(withKeyDF)))
+      }}
+    )

+    val updatedWindowAggAnchorDFMap = windowAggAnchorDFMap.filter(x => x._2.isDefined).map(x => x._1 -> x._2.get)
+
     val allInferredFeatureTypes = mutable.Map.empty[String, FeatureTypeConfig]

     windowAggFeatureStages.foreach({
       case (keyTags: Seq[Int], featureNames: Seq[String]) =>
+        if (!featureNames.diff(notJoinedFeatures.toSeq).isEmpty) {
         val stringKeyTags = keyTags.map(keyTagList).map(k => s"CAST (${k} AS string)") // restore keyTag to column names in join config

         // get the bloom filter for the key combinations in this stage
@@ -188,10 +197,10 @@
             s"${labelDataDef.dataSource.collect().take(3).map(_.toString()).mkString("\n ")}")
         }
         val windowAggAnchorsThisStage = featureNames.map(allWindowAggFeatures)
-        val windowAggAnchorDFThisStage = windowAggAnchorDFMap.filterKeys(windowAggAnchorsThisStage.toSet)
+        val windowAggAnchorDFThisStage = updatedWindowAggAnchorDFMap.filterKeys(windowAggAnchorsThisStage.toSet)

         val factDataDefs =
-          SlidingWindowFeatureUtils.getSWAAnchorGroups(windowAggAnchorDFThisStage).map {
+          SlidingWindowFeatureUtils.getSWAAnchorGroups(updatedWindowAggAnchorDFMap).map {
             anchorWithSourceToDFMap =>
               val selectedFeatures = anchorWithSourceToDFMap.keySet.flatMap(_.selectedFeatures).filter(featureNames.contains(_))
               val factData = anchorWithSourceToDFMap.head._2
@@ -244,7 +253,7 @@
             s"${factDataDef.dataSource.collect().take(3).map(_.toString()).mkString("\n ")}")
         }
       }
-    })
+    }})
     offline.FeatureDataFrame(contextDF, allInferredFeatureTypes.toMap)
   }
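`notJoinedFeatures` collects the SWA features whose source failed to load, and each join stage then runs only if at least one of its features is not in that set. A self-contained illustration of the guard (values are made up):

```scala
import scala.collection.mutable

// Features whose source data could not be loaded are recorded as "not joined".
val notJoinedFeatures = mutable.HashSet("f2")
val featureNames = Seq("f1", "f2", "f3")

val remaining = featureNames.diff(notJoinedFeatures.toSeq)  // Seq("f1", "f3")
val stageHasWork = !remaining.isEmpty                       // same as remaining.nonEmpty
```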
@@ -34,7 +34,7 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
   def getBasicAnchorDFMapForJoin(
       ss: SparkSession,
       requiredFeatureAnchors: Seq[FeatureAnchorWithSource],
-      failOnMissingPartition: Boolean): Map[FeatureAnchorWithSource, DataSourceAccessor] = {
+      failOnMissingPartition: Boolean): Map[FeatureAnchorWithSource, Option[DataSourceAccessor]] = {
     // get a Map from each source to a list of all anchors based on this source
     val sourceToAnchor = requiredFeatureAnchors
       .map(anchor => (anchor.source, anchor))
@@ -63,12 +63,12 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
           }
         }
         val timeSeriesSource = DataSourceAccessor(ss = ss,
-          source = source,
-          dateIntervalOpt = dateInterval,
-          expectDatumType = Some(expectDatumType),
+          source = source,
+          dateIntervalOpt = dateInterval,
+          expectDatumType = Some(expectDatumType),
           failOnMissingPartition = failOnMissingPartition,
           dataPathHandlers = dataPathHandlers)

         anchorsWithDate.map(anchor => (anchor, timeSeriesSource))
       })
     }
@@ -91,7 +91,7 @@
       obsTimeRange: DateTimeInterval,
       window: Duration,
       timeDelays: Array[Duration],
-      failOnMissingPartition: Boolean): DataFrame = {
+      failOnMissingPartition: Boolean): Option[DataFrame] = {

     val dataLoaderHandlers: List[DataLoaderHandler] = dataPathHandlers.map(_.dataLoaderHandler)

@@ -119,7 +119,7 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
       failOnMissingPartition = failOnMissingPartition,
       addTimestampColumn = needCreateTimestampColumn,
       dataPathHandlers = dataPathHandlers)
-    timeSeriesSource.get()
+    if (timeSeriesSource.isDefined) Some(timeSeriesSource.get.get()) else None
   }

   /**
@@ -171,8 +171,8 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
       addTimestampColumn = needCreateTimestampColumn,
       isStreaming = isStreaming,
       dataPathHandlers = dataPathHandlers)
-      anchors.map(anchor => (anchor, timeSeriesSource))

+      anchors.map(anchor => (anchor, timeSeriesSource.get))
     })
   }
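The comment in the first hunk ("get a Map from each source to a list of all anchors based on this source") describes a plain group-by, mirroring the `map`/`groupBy` shape visible in the diff. A self-contained illustration with stand-in types:

```scala
// Stand-in types: build "source -> all anchors defined on it".
case class Anchor(name: String, source: String)

val requiredFeatureAnchors = Seq(Anchor("a1", "s1"), Anchor("a2", "s1"), Anchor("a3", "s2"))

val sourceToAnchor: Map[String, Seq[Anchor]] =
  requiredFeatureAnchors
    .map(anchor => (anchor.source, anchor))
    .groupBy(_._1)
    .map { case (source, grouped) => (source, grouped.map(_._2)) }
// Map("s1" -> Seq(Anchor("a1","s1"), Anchor("a2","s1")), "s2" -> Seq(Anchor("a3","s2")))
```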
@@ -27,6 +27,7 @@ private[offline] object FeathrUtils {
    */
   val SEQ_JOIN_ARRAY_EXPLODE_ENABLED = "seq.join.array.explode.enabled"
   val ENABLE_SALTED_JOIN = "enable.salted.join"
+  val SKIP_MISSING_FEATURE = "skip.missing.feature"
   val SALTED_JOIN_FREQ_ITEM_THRESHOLD = "salted.join.freq.item.threshold"
   val SALTED_JOIN_FREQ_ITEM_ESTIMATOR = "salted.join.freq.item.estimator"
   val SALTED_JOIN_PERSIST = "salted.join.persist"
@@ -45,9 +46,10 @@
     CHECKPOINT_OUTPUT_PATH -> "/tmp/feathr/checkpoints",
     ENABLE_CHECKPOINT -> "false",
     DEBUG_OUTPUT_PART_NUM -> "200",
-    FAIL_ON_MISSING_PARTITION -> "false",
+    FAIL_ON_MISSING_PARTITION -> "true",
     SEQ_JOIN_ARRAY_EXPLODE_ENABLED -> "true",
     ENABLE_SALTED_JOIN -> "false",
+    SKIP_MISSING_FEATURE -> "true",
     // If one key appears more than 0.02% in the dataset, we will salt this join key and split them into multiple partitions
     // This is an empirical value
     SALTED_JOIN_FREQ_ITEM_THRESHOLD -> "0.0002",
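For reference, the new flag is an ordinary Feathr job parameter read from the Spark configuration. A hedged sketch of setting it on a job; the `spark.feathr.` key prefix is an assumption not confirmed by this diff, and `FeathrUtils.getFeathrJobParam` encapsulates the real lookup and default:

```scala
import org.apache.spark.SparkConf

// Assumed key prefix; the defaults table above supplies "true" when the key is unset.
val conf = new SparkConf().set("spark.feathr.skip.missing.feature", "false")
val skipMissingFeature =
  conf.get("spark.feathr.skip.missing.feature", "true").toBoolean   // false: fail instead of skipping
```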