
Commit f4c8124
Merge pull request feathr-ai#1 from DSNP-CCCM/change_agg_logic
first attempt at adapting feathr to cccm use cases
Tian Zhou authored and GitHub Enterprise committed Aug 22, 2023
2 parents 64df5b6 + 42b4043 commit f4c8124
Showing 11 changed files with 58 additions and 24 deletions.
@@ -26,6 +26,7 @@ public class SlidingWindowAggregationBuilder extends SlidingWindowOperationBuild
put(TimeWindowAggregationType.MAX, AggregationType.MAX);
put(TimeWindowAggregationType.SUM, AggregationType.SUM);
put(TimeWindowAggregationType.COUNT, AggregationType.COUNT);
+ // put(TimeWindowAggregationType.COUNT_DISTINCT, AggregationType.COUNT_DISTINCT);
put(TimeWindowAggregationType.LATEST, AggregationType.LATEST);
put(TimeWindowAggregationType.AVG_POOLING, AggregationType.AVG_POOLING);
put(TimeWindowAggregationType.MAX_POOLING, AggregationType.MAX_POOLING);
@@ -5,5 +5,5 @@
* Enumeration class for Sliding time-window aggregation
*/
public enum TimeWindowAggregationType {
- SUM, COUNT, AVG, MAX, MIN, TIMESINCE, LATEST, AVG_POOLING, MAX_POOLING, MIN_POOLING
+ SUM, COUNT, AVG, MAX, MIN, TIMESINCE, LATEST, AVG_POOLING, MAX_POOLING, MIN_POOLING, COUNT_DISTINCT
}
4 changes: 2 additions & 2 deletions feathr-config/src/main/resources/FeatureDefConfigSchema.json
@@ -728,7 +728,7 @@
"$ref":"#/anchor/defaultValue"
},
"aggregation": {
"enum": ["SUM", "COUNT", "MAX", "MIN", "AVG", "LATEST", "AVG_POOLING", "MAX_POOLING", "MIN_POOLING"]
"enum": ["SUM", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "AVG", "LATEST", "AVG_POOLING", "MAX_POOLING", "MIN_POOLING"]
},
"window": {
"$ref":"#/anchor/durationPattern"
@@ -756,7 +756,7 @@
"$ref": "#/anchor/defExpr"
},
"aggregation": {
"enum": ["SUM", "COUNT", "MAX", "AVG", "AVG_POOLING", "MAX_POOLING", "MIN_POOLING"]
"enum": ["SUM", "COUNT", "COUNT_DISTINCT", "MAX", "AVG", "AVG_POOLING", "MAX_POOLING", "MIN_POOLING"]
},
"windowParameters": {
"type": "object",
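The enum lists above gate which aggregation strings a feature definition may declare. A minimal sketch (not part of this commit) of how the new value could be exercised from the feathr Python client, assuming agg_func strings are passed through unchanged to TimeWindowAggregationType; Feature and WindowAggTransformation follow feathr's documented Python API, but the feature name and column here are hypothetical:

    from feathr import Feature, WindowAggTransformation, INT32

    # Hypothetical feature for illustration only; name, column, and key setup are assumptions.
    f_distinct_items_7d = Feature(
        name="f_distinct_items_7d",
        feature_type=INT32,
        transform=WindowAggTransformation(
            agg_expr="item_id",          # column whose distinct values are counted
            agg_func="COUNT_DISTINCT",   # newly allowed by the schema change above
            window="7d",                 # sliding-window length
        ),
    )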
@@ -69,6 +69,13 @@ private[offline] class TimeWindowConfigurableAnchorExtractor(@JsonProperty("feat
}
val aggFuncName = featureDef.timeWindowFeatureDefinition.aggregationType.toString
val aggType = AggregationType.withName(aggFuncName)
+
+ println(aggType)
+ println(aggFuncName)
+ // throw new FeathrConfigException(
+ //   ErrorLabel.FEATHR_USER_ERROR,
+ //   s"heyheyhey ${aggFuncName} ${aggType}")
+
val colName = getFeatureColumnName(featureName, aggFuncName)

val baseAggCol = if (featureDef.timeWindowFeatureDefinition.groupBy.isDefined) {
@@ -91,14 +98,15 @@ private[offline] class TimeWindowConfigurableAnchorExtractor(@JsonProperty("feat
case AggregationType.SUM => sum(expr(colName))
case AggregationType.AVG => avg(expr(colName))
case AggregationType.COUNT => count(expr(colName))
+ case AggregationType.COUNT_DISTINCT => approx_count_distinct(expr(colName))
case AggregationType.MAX_POOLING => first(expr(colName))
case AggregationType.MIN_POOLING => first(expr(colName))
case AggregationType.AVG_POOLING => first(expr(colName))
case AggregationType.LATEST => last(expr(colName), true)
- case tp =>
-   throw new FeathrConfigException(
-     ErrorLabel.FEATHR_USER_ERROR,
-     s"AggregationType ${tp} is not supported in aggregateAsColumns of TimeWindowConfigurableAnchorExtractor.")
+ // case tp =>
+ //   throw new FeathrConfigException(
+ //     ErrorLabel.FEATHR_USER_ERROR,
+ //     s"AggregationType ${tp} is not supported in aggregateAsColumns of TimeWindowConfigurableAnchorExtractor.")
}
aggCol.alias(featureName)
}
@@ -160,6 +168,7 @@ private[offline] class TimeWindowConfigurableAnchorExtractor(@JsonProperty("feat
case AggregationType.SUM => sum(metricColExpr)
case AggregationType.AVG => avg(metricColExpr)
case AggregationType.COUNT => count(metricColExpr)
+ case AggregationType.COUNT_DISTINCT => approx_count_distinct(metricColExpr)
case tp =>
throw new FeathrConfigException(
ErrorLabel.FEATHR_USER_ERROR,
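Note that this extractor maps COUNT_DISTINCT to Spark's approx_count_distinct, which estimates cardinality with HyperLogLog++ rather than computing it exactly. A small PySpark sketch (illustrative, not part of this commit) of the difference and the tunable error bound:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import approx_count_distinct, countDistinct

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    df = spark.createDataFrame(
        [(1, "a"), (1, "a"), (1, "b"), (2, "c")], ["key", "item"])

    df.groupBy("key").agg(
        countDistinct("item").alias("exact"),                     # exact distinct count
        approx_count_distinct("item", rsd=0.05).alias("approx"),  # HyperLogLog++; rsd = max relative error
    ).show()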
@@ -16,7 +16,8 @@ import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstitut
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.{FDS_TENSOR, FeatureColumnFormat, RAW}
import com.linkedin.feathr.swj.{FactData, GroupBySpec, LabelData, LateralViewParams, SlidingWindowFeature, SlidingWindowJoin, WindowSpec}
- import com.linkedin.feathr.swj.aggregate.{AggregationSpec, AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
+ import com.linkedin.feathr.swj.aggregate.{AggregationSpec, AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, CountDistinctAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
+
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.DataFrame

@@ -77,6 +78,7 @@ object AggregationNodeEvaluator extends NodeEvaluator {
// In feathr's use case, we want to treat the count aggregation as simple count of non-null items.
val rewrittenDef = s"CASE WHEN ${featureDef} IS NOT NULL THEN 1 ELSE 0 END"
new CountAggregate(rewrittenDef)
+ case AggregationType.COUNT_DISTINCT => new CountDistinctAggregate(featureDef)
case AggregationType.AVG => new AvgAggregate(featureDef) // TODO: deal with avg. of pre-aggregated data
case AggregationType.MAX => new MaxAggregate(featureDef)
case AggregationType.MIN => new MinAggregate(featureDef)
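The CountDistinctAggregate class referenced here is added elsewhere in this commit and does not appear in the diff. A common way to compute an exact distinct count in an aggregation pipeline is to accumulate a deduplicated set and take its size at the end; the PySpark sketch below shows that general strategy, as an assumption about the approach rather than the class's actual implementation:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, collect_set, size

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    df = spark.createDataFrame(
        [(1, "a"), (1, "a"), (1, "b"), (2, "c")], ["key", "item"])

    # collect_set deduplicates per group; size turns the set into a distinct count
    df.groupBy("key").agg(
        size(collect_set(col("item"))).alias("item_count_distinct")).show()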
@@ -50,7 +50,8 @@ object AnchorSQLOperator extends TransformationOperator {
val sqlKeyExtractor = new SQLSourceKeyExtractor(keySeq)
val withKeyColumnDF = if (appendKeyColumns) sqlKeyExtractor.appendKeyColumns(inputDf) else inputDf
val withFeaturesDf = createFeatureDF(withKeyColumnDF, transformedCols.keys.toSeq)
- val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withFeaturesDf)
+ // val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withFeaturesDf)
+ val outputJoinKeyColumnNames = sqlKeyExtractor.getKeyColumnNames()

// Mark as FDS format if it is the FDSExtract SQL function
featureNameToSqlExpr.filter(ele => ele._2.featureExpr.contains(USER_FACING_MULTI_DIM_FDS_TENSOR_UDF_NAME))
@@ -55,7 +55,9 @@ object AnchorUDFOperator extends TransformationOperator {
// Note that for Spark UDFs we only support SQL keys.
val sqlKeyExtractor = new SQLSourceKeyExtractor(keySeq)
val withKeyColumnDF = if (appendKeyColumns) sqlKeyExtractor.appendKeyColumns(inputDf) else inputDf
- val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withKeyColumnDF)
+ // val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withKeyColumnDF)
+ val outputJoinKeyColumnNames = sqlKeyExtractor.getKeyColumnNames()
+

val tensorizedFeatureColumns = sparkExtractor.getFeatures(inputDf, Map())
val transformedColsAndFormats: Map[(String, Column), FeatureColumnFormat] = extractor match {
@@ -92,7 +94,8 @@ object AnchorUDFOperator extends TransformationOperator {
// Note that for Spark UDFs we only support SQL keys.
val sqlKeyExtractor = new SQLSourceKeyExtractor(keySeq)
val withKeyColumnDF = if (appendKeyColumns) sqlKeyExtractor.appendKeyColumns(inputDf) else inputDf
- val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withKeyColumnDF)
+ // val outputJoinKeyColumnNames = getFeatureKeyColumnNames(sqlKeyExtractor, withKeyColumnDF)
+ val outputJoinKeyColumnNames = sqlKeyExtractor.getKeyColumnNames()

val transformedDF = sparkExtractor.transform(inputDf)
(transformedDF, outputJoinKeyColumnNames)
@@ -71,17 +71,19 @@ object SparkIOUtils {
}
}
if(!dfWritten) {
- val num_parts = parameters.get(FeathrUtils.DEBUG_OUTPUT_PART_NUM).getOrElse("10").toInt
+ val num_parts = parameters.get(FeathrUtils.DEBUG_OUTPUT_PART_NUM).getOrElse("10000").toInt
// Honor the debug output part num config
- val coalescedDf = outputDF.coalesce(num_parts)
+ // val coalescedDf = outputDF.coalesce(num_parts)
+ val coalescedDf = outputDF
outputLocation match {
case SimplePath(path) => {
val output_format = coalescedDf.sqlContext.getConf("spark.feathr.outputFormat", "avro")
// if the output format is set by spark configurations "spark.feathr.outputFormat"
// we will use that as the job output format; otherwise use avro as default for backward compatibility
- if(!outputDF.isEmpty) {
-   coalescedDf.write.mode(SaveMode.Overwrite).format(output_format).save(path)
- }
+ // if(!outputDF.isEmpty) {
+ //   coalescedDf.write.mode(SaveMode.Overwrite).format(output_format).save(path)
+ // }
+ coalescedDf.write.mode(SaveMode.Overwrite).format(output_format).save(path)
}
case _ => outputLocation.writeDf(SparkSession.builder().getOrCreate(), coalescedDf, None)
}
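This hunk drops the coalesce and writes unconditionally, which preserves upstream write parallelism but can produce many small output files. A hedged PySpark sketch of keeping both behaviors behind a flag; the spark.feathr.output.coalesce and spark.feathr.debug.output.num.parts keys are hypothetical, not real feathr settings:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    output_df = spark.range(1000).toDF("value")

    num_parts = int(spark.conf.get("spark.feathr.debug.output.num.parts", "10000"))
    if spark.conf.get("spark.feathr.output.coalesce", "false") == "true":
        out_df = output_df.coalesce(num_parts)  # fewer, larger files; less write parallelism
    else:
        out_df = output_df                      # keep the upstream partitioning as-is
    out_df.write.mode("overwrite").parquet("/tmp/feathr_debug_output")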
@@ -109,11 +109,13 @@ private[offline] object FeatureTransformation {
* @return feature key column names
*/
def getFeatureKeyColumnNames(sourceKeyExtractor: SourceKeyExtractor, withKeyColumnDF: DataFrame): Seq[String] = {
- if (withKeyColumnDF.head(1).isEmpty) {
-   sourceKeyExtractor.getKeyColumnNames(None)
- } else {
-   sourceKeyExtractor.getKeyColumnNames(Some(withKeyColumnDF.first()))
- }
+ // if (withKeyColumnDF.head(1).isEmpty) {
+ //   sourceKeyExtractor.getKeyColumnNames(None)
+ // } else {
+ //   sourceKeyExtractor.getKeyColumnNames(Some(withKeyColumnDF.first()))
+ // }
+ // sourceKeyExtractor.getKeyColumnNames(Some(withKeyColumnDF.first()))
+ sourceKeyExtractor.getKeyColumnNames(None)
}

// get the feature column prefix which will be appended to all feature columns of the dataframe returned by the transformer
@@ -1457,9 +1459,10 @@ private[offline] object FeatureTransformation {
private[offline] def getStandardizedKeyNames(joinKeySize: Int) = {
Range(0, joinKeySize).map("key" + _)
}
- // max number of feature groups that can be calculated at the same time
+ // max number of feature groups (features from the same source) that can be calculated at the same time
// each group will be a separate spark job
- private val MAX_PARALLEL_FEATURE_GROUP = 10
+ // private val MAX_PARALLEL_FEATURE_GROUP = 10
+ private val MAX_PARALLEL_FEATURE_GROUP = sys.env.getOrElse("MAX_PARALLEL_FEATURE_GROUP", "100").toInt
}

private[offline] case class FeatureTypeInferenceContext(featureTypeAccumulators: Map[String, FeatureTypeAccumulator])
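The getFeatureKeyColumnNames rewrite above (and the matching swaps in AnchorSQLOperator and AnchorUDFOperator) matters because head(1) and first() are Spark actions: the old code launched a job per feature group just to resolve key column names, whereas passing None resolves them on the driver. A tiny PySpark illustration of the cost difference:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    df = spark.range(10).toDF("key")

    df.head(1)    # an action: schedules a distributed job just to fetch one row
    df.columns    # driver-side metadata lookup: no job at all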
13 changes: 13 additions & 0 deletions feathr_project/feathr/client.py
@@ -849,6 +849,19 @@ def _materialize_features_with_config(
if monitoring_config_str:
arguments.append('--monitoring-config')
arguments.append(monitoring_config_str)
+
+ print("job_name", self.project_name + '_feathr_feature_materialization_job')
+ print("main_jar_path", self._FEATHR_JOB_JAR_PATH)
+ print("python_files", cloud_udf_paths)
+ print("job_tags", job_tags)
+ print("arguments", arguments)
+ print("configuration", execution_configurations)
+ self.logger.info(f"job_name {self.project_name + '_feathr_feature_materialization_job'}")
+ self.logger.info(f"main_jar_path {self._FEATHR_JOB_JAR_PATH}")
+ self.logger.info(f"python_files {cloud_udf_paths}")
+ self.logger.info(f"job_tags {job_tags}")
+ self.logger.info(f"arguments {arguments}")
+ self.logger.info(f"configuration {execution_configurations}")
return self.feathr_spark_launcher.submit_feathr_job(
job_name=self.project_name + '_feathr_feature_materialization_job',
main_jar_path=self._FEATHR_JOB_JAR_PATH,
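On the logging lines above: a stdlib logging.Logger formats its message lazily as msg % args, so print-style positional arguments with no % placeholder fail at emit time, while eager f-strings sidestep the issue. A minimal sketch, assuming self.logger is a stdlib logging.Logger:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("feathr")
    job_name = "project_feathr_feature_materialization_job"

    logger.info("job_name %s", job_name)   # lazy %-formatting, evaluated only if the record is emitted
    logger.info(f"job_name {job_name}")    # eager f-string, also works
    # logger.info("job_name", job_name)    # broken: logging computes "job_name" % (job_name,)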
2 changes: 1 addition & 1 deletion registry/data-models/transformation/models.py
@@ -11,7 +11,7 @@ class SlidingWindowAggregationType(Enum):
MAX = "maximum"
MIN = "minium"
AVG = "average"

COUNT_DISTINCT = "count_distinct"

class SlidingWindowEmbeddingAggregationType(Enum):
"""
