From ffbce661d0f8dedc26732f59e6147a71ade2b453 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 23 Jan 2015 10:48:58 -0800 Subject: [PATCH] Fixed compilation error. --- .../org/apache/spark/sql/dsl/package.scala | 18 ++++++++++++++- .../sql/sources/DataSourceStrategy.scala | 22 ++++++++++--------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala index 283ad009db5ab..29c3d26ae56d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala @@ -98,12 +98,28 @@ package object dsl { val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** - * Register a Scala closure of ${x} arguments as user-defined function (UDF). + * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. */ def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) }""") } + + (0 to 22).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } */ /** * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d03019..7d57196b03c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -113,22 +113,24 @@ private[sql] object DataSourceStrategy extends Strategy { } protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => + LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => + GreaterThanOrEqual(a.name, v) case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) }