From 4a091232122b51f10521a68de8b1d9eb853b563d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 27 Jan 2016 15:35:31 -0800 Subject: [PATCH 001/107] [SPARK-13045] [SQL] Remove ColumnVector.Struct in favor of ColumnarBatch.Row These two classes became identical as the implementation progressed. Author: Nong Li Closes #10952 from nongli/spark-13045. --- .../execution/vectorized/ColumnVector.java | 104 +----------------- .../execution/vectorized/ColumnarBatch.java | 40 ++++--- .../vectorized/ColumnarBatchSuite.scala | 8 +- 3 files changed, 32 insertions(+), 120 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index c119758d68b36..a0bf8734b6545 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -210,104 +210,6 @@ public Object get(int ordinal, DataType dataType) { } } - /** - * Holder object to return a struct. This object is intended to be reused. - */ - public static final class Struct extends InternalRow { - // The fields that make up this struct. For example, if the struct had 2 int fields, the access - // to it would be: - // int f1 = fields[0].getInt[rowId] - // int f2 = fields[1].getInt[rowId] - public final ColumnVector[] fields; - - @Override - public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); } - - @Override - public boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } - - public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); } - - @Override - public short getShort(int ordinal) { - throw new NotImplementedException(); - } - - public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); } - public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); } - - @Override - public float getFloat(int ordinal) { - throw new NotImplementedException(); - } - - public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - Array a = getByteArray(ordinal); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); - } - - @Override - public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - return fields[ordinal].getStruct(rowId); - } - - public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); } - - @Override - public MapData getMap(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); - } - - public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); } - public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); } - - @Override - public final int numFields() { - return fields.length; - } - - @Override - public InternalRow copy() { - throw new NotImplementedException(); - } - - @Override - public boolean anyNull() { - throw new NotImplementedException(); - } - - protected int rowId; - - protected 
Struct(ColumnVector[] fields) { - this.fields = fields; - } - } - /** * Returns the data type of this column. */ @@ -494,7 +396,7 @@ public void reset() { /** * Returns a utility object to get structs. */ - public Struct getStruct(int rowId) { + public ColumnarBatch.Row getStruct(int rowId) { resultStruct.rowId = rowId; return resultStruct; } @@ -749,7 +651,7 @@ public final int appendStruct(boolean isNull) { /** * Reusable Struct holder for getStruct(). */ - protected final Struct resultStruct; + protected final ColumnarBatch.Row resultStruct; /** * Sets up the common state and also handles creating the child columns if this is a nested @@ -779,7 +681,7 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); } this.resultArray = null; - this.resultStruct = new Struct(this.childColumns); + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index d558dae50c227..5a575811fa896 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -86,13 +86,23 @@ public void close() { * performance is lost with this translation. */ public static final class Row extends InternalRow { - private int rowId; + protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; + private final ColumnVector[] columns; + // Ctor used if this is a top level row. private Row(ColumnarBatch parent) { this.parent = parent; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + } + + // Ctor used if this is a struct. + protected Row(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; } /** @@ -103,23 +113,23 @@ public final void markFiltered() { parent.markFiltered(rowId); } + public ColumnVector[] columns() { return columns; } + @Override - public final int numFields() { - return parent.numCols(); - } + public final int numFields() { return columns.length; } @Override /** * Revisit this. This is expensive. 
*/ public final InternalRow copy() { - UnsafeRow row = new UnsafeRow(parent.numCols()); + UnsafeRow row = new UnsafeRow(numFields()); row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); - for (int i = 0; i < parent.numCols(); i++) { + for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = parent.schema.fields()[i].dataType(); + DataType dt = columns[i].dataType(); if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { @@ -141,7 +151,7 @@ public final boolean anyNull() { @Override public final boolean isNullAt(int ordinal) { - return parent.column(ordinal).getIsNull(rowId); + return columns[ordinal].getIsNull(rowId); } @Override @@ -150,7 +160,7 @@ public final boolean getBoolean(int ordinal) { } @Override - public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); } + public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override public final short getShort(int ordinal) { @@ -159,11 +169,11 @@ public final short getShort(int ordinal) { @Override public final int getInt(int ordinal) { - return parent.column(ordinal).getInt(rowId); + return columns[ordinal].getInt(rowId); } @Override - public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); } + public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override public final float getFloat(int ordinal) { @@ -172,7 +182,7 @@ public final float getFloat(int ordinal) { @Override public final double getDouble(int ordinal) { - return parent.column(ordinal).getDouble(rowId); + return columns[ordinal].getDouble(rowId); } @Override @@ -182,7 +192,7 @@ public final Decimal getDecimal(int ordinal, int precision, int scale) { @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId); + ColumnVector.Array a = columns[ordinal].getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } @@ -198,12 +208,12 @@ public final CalendarInterval getInterval(int ordinal) { @Override public final InternalRow getStruct(int ordinal, int numFields) { - return parent.column(ordinal).getStruct(rowId); + return columns[ordinal].getStruct(rowId); } @Override public final ArrayData getArray(int ordinal) { - return parent.column(ordinal).getArray(rowId); + return columns[ordinal].getArray(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 215ca9ab6b770..67cc08b6fc8ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -439,10 +439,10 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(1, 5.67) val s = column.getStruct(0) - assert(s.fields(0).getInt(0) == 123) - assert(s.fields(0).getInt(1) == 456) - assert(s.fields(1).getDouble(0) == 3.45) - assert(s.fields(1).getDouble(1) == 5.67) + assert(s.columns()(0).getInt(0) == 123) + assert(s.columns()(0).getInt(1) == 456) + assert(s.columns()(1).getDouble(0) == 3.45) + assert(s.columns()(1).getDouble(1) == 5.67) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) From c2204436a15838f2dce44e3cfb0fe58236ef6196 Mon Sep 17 00:00:00 2001 From: James Lohse Date: Thu, 28 Jan 2016 10:50:50 +0000 
Subject: [PATCH 002/107] Provide same info as in spark-submit --help this is stated for --packages and --repositories. Without stating it for --jars, people expect a standard java classpath to work, with expansion and using a different delimiter than a comma. Currently this is only state in the --help for spark-submit "Comma-separated list of local jars to include on the driver and executor classpaths." Author: James Lohse Closes #10890 from jimlohse/patch-1. --- docs/submitting-applications.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index acbb0f298fe47..413532f2f6cfa 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,8 +177,9 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. Spark uses the following URL scheme to allow -different strategies for disseminating jars: +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. + +Spark uses the following URL scheme to allow different strategies for disseminating jars: - **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor pulls the file from the driver HTTP server. From 415d0a859b7a76f3a866ec62ab472c4050f2a01b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 28 Jan 2016 12:26:03 -0800 Subject: [PATCH 003/107] [SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`. Author: Cheng Lian Closes #10968 from liancheng/cms-specialized. --- .../spark/util/sketch/CountMinSketch.java | 34 +++++++++- .../spark/util/sketch/CountMinSketchImpl.java | 35 ++++++++-- .../spark/sql/DataFrameStatFunctions.scala | 65 +++++++++++-------- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 5692e574d4c7e..f0aac5bb00dfb 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -115,15 +115,45 @@ int getVersionNumber() { public abstract long totalCount(); /** - * Adds 1 to {@code item}. + * Increments {@code item}'s count by one. */ public abstract void add(Object item); /** - * Adds {@code count} to {@code item}. + * Increments {@code item}'s count by {@code count}. */ public abstract void add(Object item, long count); + /** + * Increments {@code item}'s count by one. + */ + public abstract void addLong(long item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addLong(long item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addString(String item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addString(String item, long count); + + /** + * Increments {@code item}'s count by one. 
+ */ + public abstract void addBinary(byte[] item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addBinary(byte[] item, long count); + /** * Returns the estimated frequency of {@code item}. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e49ae22906c4c..c0631c6778df4 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,7 +25,6 @@ import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; -import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; @@ -146,27 +145,49 @@ public void add(Object item, long count) { } } - private void addString(String item, long count) { + @Override + public void addString(String item) { + addString(item, 1); + } + + @Override + public void addString(String item, long count) { + addBinary(Utils.getBytesFromUTF8String(item), count); + } + + @Override + public void addLong(long item) { + addLong(item, 1); + } + + @Override + public void addLong(long item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } - int[] buckets = getHashBuckets(item, depth, width); - for (int i = 0; i < depth; ++i) { - table[i][buckets[i]] += count; + table[i][hash(item, i)] += count; } totalCount += count; } - private void addLong(long item, long count) { + @Override + public void addBinary(byte[] item) { + addBinary(item, 1); + } + + @Override + public void addBinary(byte[] item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { - table[i][hash(item, i)] += count; + table[i][buckets[i]] += count; } totalCount += count; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b0b6995a2214f..bb3cc02800d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** @@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * - * * @param col1 The name of the first column. Distinct items will make the first item of * each row. * @param col2 The name of the second column. Distinct items will make the column names @@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { val singleCol = df.select(col) val colType = singleCol.schema.head.dataType - require( - colType == StringType || colType.isInstanceOf[IntegralType], - s"Count-min Sketch only supports string type and integral types, " + - s"and does not support type $colType." 
- ) + val updater: (CountMinSketch, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` + // instead of `addString` to avoid unnecessary conversion. + case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) + case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) + case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) + case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) + case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + } - singleCol.rdd.aggregate(zero)( - (sketch: CountMinSketch, row: Row) => { - sketch.add(row.get(0)) + singleCol.queryExecution.toRdd.aggregate(zero)( + (sketch: CountMinSketch, row: InternalRow) => { + updater(sketch, row) sketch }, - - (sketch1: CountMinSketch, sketch2: CountMinSketch) => { - sketch1.mergeInPlace(sketch2) - } + (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) ) } @@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { require(colType == StringType || colType.isInstanceOf[IntegralType], s"Bloom filter only supports string type and integral types, but got $colType.") - val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { - (filter, row) => - // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` - // instead of `putString` to avoid unnecessary conversion. - filter.putBinary(row.getUTF8String(0).getBytes) - filter - } else { - (filter, row) => - // TODO: specialize it. - filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) - filter + val updater: (BloomFilter, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes) + case ByteType => (filter, row) => filter.putLong(row.getByte(0)) + case ShortType => (filter, row) => filter.putLong(row.getShort(0)) + case IntegerType => (filter, row) => filter.putLong(row.getInt(0)) + case LongType => (filter, row) => filter.putLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Bloom filter only supports string type and integral types, " + + s"and does not support type $colType." + ) } - singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + singleCol.queryExecution.toRdd.aggregate(zero)( + (filter: BloomFilter, row: InternalRow) => { + updater(filter, row) + filter + }, + (filter1, filter2) => filter1.mergeInPlace(filter2) + ) } } From 676803963fcc08aa988aa6f14be3751314e006ca Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 28 Jan 2016 13:45:28 -0800 Subject: [PATCH 004/107] [SPARK-12926][SQL] SQLContext to display warning message when non-sql configs are being set Users unknowingly try to set core Spark configs in SQLContext but later realise that it didn't work. eg. sqlContext.sql("SET spark.shuffle.memoryFraction=0.4"). This PR adds a warning message when such operations are done. Author: Tejas Patil Closes #10849 from tejasapatil/SPARK-12926. 
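The check itself is just a key-prefix test applied whenever a value is written into the conf map. A minimal standalone sketch of that idea (the object name and `println` below are illustrative only; the actual change does this inside a new `SQLConf.setConfWithCheck` helper using `logWarning`, as the diff that follows shows):

```
object NonSqlConfCheckSketch {
  // A key targets core Spark (not Spark SQL) when it starts with "spark." but not "spark.sql.".
  def isNonSqlConf(key: String): Boolean =
    key.startsWith("spark.") && !key.startsWith("spark.sql.")

  def main(args: Array[String]): Unit = {
    // e.g. sqlContext.sql("SET spark.shuffle.memoryFraction=0.4") ends up setting this key:
    val key = "spark.shuffle.memoryFraction"
    if (isNonSqlConf(key)) {
      println(s"Warning: attempt to set non-Spark SQL config in SQLConf: key = $key")
    }
  }
}
```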
--- .../main/scala/org/apache/spark/sql/SQLConf.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index c9ba6700998c3..eb9da0bd4fd4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.util.Utils @@ -519,7 +520,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -628,7 +629,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon // Only verify configs in the SQLConf object entry.valueConverter(value) } - settings.put(key, value) + setConfWithCheck(key, value) } /** Set the given Spark SQL configuration property. */ @@ -636,7 +637,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon require(entry != null, "entry cannot be null") require(value != null, s"value cannot be null for key: ${entry.key}") require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - settings.put(entry.key, entry.stringConverter(value)) + setConfWithCheck(entry.key, entry.stringConverter(value)) } /** Return the value of Spark SQL configuration property for the given key. */ @@ -699,6 +700,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon }.toSeq } + private def setConfWithCheck(key: String, value: String): Unit = { + if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) { + logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value") + } + settings.put(key, value) + } + private[spark] def unsetConf(key: String): Unit = { settings.remove(key) } From cc18a7199240bf3b03410c1ba6704fe7ce6ae38e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 13:51:55 -0800 Subject: [PATCH 005/107] [SPARK-13031] [SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. Author: Davies Liu Closes #10944 from davies/gen_refactor. 
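To make item 2 above concrete, this is a condensed view of the reworked CodegenSupport contract after the patch (signatures lifted from the diff below; bodies and unrelated members are elided, so treat it as a sketch rather than the full trait):

```
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

trait CodegenSupport {  // extends SparkPlan in the real code
  /** The RDD that feeds rows into the generated iterator (split out of the old doProduce). */
  def upstream(): RDD[InternalRow]

  /** Now returns only the generated Java source; the input RDD comes from upstream() instead. */
  protected def doProduce(ctx: CodegenContext): String

  /** Hands either column variables or an already-built UnsafeRow (`row`) up to the parent. */
  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String
}
```

Project, Filter, Range and TungstenAggregate are updated below to implement this shape.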
--- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/TungstenAggregate.scala | 88 +++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 334 insertions(+), 202 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad374..e6704cf8bb1f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6ef..ec31db19b94b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de9804..ef81ba60f049f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators 
that support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. */ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. 
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - 
plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d252..cbd2634b8900f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,50 +139,80 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + 
ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else { + // output the aggregate buffer directly + (bufVars, "") + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val updates = updateExpr.zipWithIndex.map { case (e, i) => + val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -190,7 +222,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5ec..e7a73d5fbb4bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => 
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | 
if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918e..51a50c1fa30e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. 
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. 
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269d..82f6811503c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+ assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac0..7d6bff8295d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a254..a3e5243b68aba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From df78a934a07a4ce5af43243be9ba5fe60b91eee6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 28 Jan 2016 14:29:47 -0800 Subject: [PATCH 006/107] [SPARK-9835][ML] Implement IterativelyReweightedLeastSquares solver Implement ```IterativelyReweightedLeastSquares``` solver for GLM. I consider it as a solver rather than estimator, it only used internal so I keep it ```private[ml]```. There are two limitations in the current implementation compared with R: * It can not support ```Tuple``` as response for ```Binomial``` family, such as the following code: ``` glm( cbind(using, notUsing) ~ age + education + wantsMore , family = binomial) ``` * It does not support ```offset```. Because I considered that ```RFormula``` did not support ```Tuple``` as label and ```offset``` keyword, so I simplified the implementation. But to add support for these two functions is not very hard, I can do it in follow-up PR if it is necessary. Meanwhile, we can also add R-like statistic summary for IRLS. The implementation refers R, [statsmodels](https://github.com/statsmodels/statsmodels) and [sparkGLM](https://github.com/AlteryxLabs/sparkGLM). Please focus on the main structure and overpass minor issues/docs that I will update later. Any comments and opinions will be appreciated. cc mengxr jkbradley Author: Yanbo Liang Closes #10639 from yanboliang/spark-9835. 
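For orientation, a rough usage sketch of the new solver (a sketch only: the class is `private[ml]`, so a real caller would live inside `org.apache.spark.ml`, and the reweight function here is a placeholder — each GLM family or M-estimator supplies its own (offset, weight) update; signatures follow the code added below):

```
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.optim.{IterativelyReweightedLeastSquares, WeightedLeastSquares, WeightedLeastSquaresModel}
import org.apache.spark.rdd.RDD

object IrlsUsageSketch {
  def fitWithIrls(instances: RDD[Instance]): Unit = {
    // Start from an ordinary weighted least squares fit as the initial guess.
    val initial = new WeightedLeastSquares(fitIntercept = true, regParam = 0.0,
      standardizeFeatures = false, standardizeLabel = false).fit(instances)

    // Placeholder reweight function: returns (new offset, new weight) for each instance.
    val reweight = (instance: Instance, model: WeightedLeastSquaresModel) => {
      val eta = model.predict(instance.features)
      (eta, instance.weight)  // a real family would transform these values
    }

    val irls = new IterativelyReweightedLeastSquares(initial, reweight,
      fitIntercept = true, regParam = 0.0, maxIter = 25, tol = 1e-8)
    val model = irls.fit(instances)
    println(s"coefficients = ${model.coefficients}, intercept = ${model.intercept}")
  }
}
```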
--- .../IterativelyReweightedLeastSquares.scala | 108 ++++++++++ .../spark/ml/optim/WeightedLeastSquares.scala | 7 +- ...erativelyReweightedLeastSquaresSuite.scala | 200 ++++++++++++++++++ 3 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala new file mode 100644 index 0000000000000..6aa44e6ba723e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[IterativelyReweightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + */ +private[ml] class IterativelyReweightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double) extends Serializable + +/** + * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve + * certain optimization problems by an iterative method. In each step of the iterations, it + * involves solving a weighted lease squares (WLS) problem by [[WeightedLeastSquares]]. + * It can be used to find maximum likelihood estimates of a generalized linear model (GLM), + * find M-estimator in robust regression and other optimization problems. + * + * @param initialModel the initial guess model. + * @param reweightFunc the reweight function which is used to update offsets and weights + * at each iteration. + * @param fitIntercept whether to fit intercept. + * @param regParam L2 regularization parameter used by WLS. + * @param maxIter maximum number of iterations. + * @param tol the convergence tolerance. + * + * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares + * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives, + * Journal of the Royal Statistical Society. 
Series B, 1984.]] + */ +private[ml] class IterativelyReweightedLeastSquares( + val initialModel: WeightedLeastSquaresModel, + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val fitIntercept: Boolean, + val regParam: Double, + val maxIter: Int, + val tol: Double) extends Logging with Serializable { + + def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + + var converged = false + var iter = 0 + + var model: WeightedLeastSquaresModel = initialModel + var oldModel: WeightedLeastSquaresModel = null + + while (iter < maxIter && !converged) { + + oldModel = model + + // Update offsets and weights using reweightFunc + val newInstances = instances.map { instance => + val (newOffset, newWeight) = reweightFunc(instance, oldModel) + Instance(newOffset, newWeight, instance.features) + } + + // Estimate new model + model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, + standardizeLabel = false).fit(newInstances) + + // Check convergence + val oldCoefficients = oldModel.coefficients + val coefficients = model.coefficients + BLAS.axpy(-1.0, coefficients, oldCoefficients) + val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + math.max(math.abs(x), math.abs(y)) + } + val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) + + if (maxTol < tol) { + converged = true + logInfo(s"IRLS converged in $iter iterations.") + } + + logInfo(s"Iteration $iter : relative tolerance = $maxTol") + iter = iter + 1 + + if (iter == maxIter) { + logInfo(s"IRLS reached the max number of iterations: $maxIter.") + } + + } + + new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 797870eb8ce8a..61b3642131810 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable + val diagInvAtWA: DenseVector) extends Serializable { + + def predict(features: Vector): Double = { + BLAS.dot(coefficients, features) + intercept + } +} /** * Weighted least squares solver via normal equation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala new file mode 100644 index 0000000000000..604021220a139 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances1: RDD[Instance] = _ + private var instances2: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + */ + instances1 = sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + */ + instances2 = sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + } + + test("IRLS against GLM with Binomial errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="binomial", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.30216651 -0.04452045 + [1] 3.5651651 -1.2334085 -0.7348971 + */ + val expected = Seq( + Vectors.dense(0.0, -0.30216651, -0.04452045), + Vectors.dense(3.5651651, -1.2334085, -0.7348971)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val newInstances = instances1.map { instance => + val mu = (instance.label + 0.5) / 2.0 + val eta = math.log(mu / (1.0 - mu)) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against GLM with Poisson errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . 
-1, b ~ .)) { + model <- glm(formula, family="poisson", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.09607792 0.18375613 + [1] 6.299947 3.324107 -1.081766 + */ + val expected = Seq( + Vectors.dense(0.0, -0.09607792, 0.18375613), + Vectors.dense(6.299947, 3.324107, -1.081766)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val yMean = instances2.map(_.label).mean + val newInstances = instances2.map { instance => + val mu = (instance.label + yMean) / 2.0 + val eta = math.log(mu) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against L1Regression") { + /* + R code: + + library(quantreg) + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- rq(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] 1.266667 0.400000 + [1] 29.5 17.0 -5.5 + */ + val expected = Seq( + Vectors.dense(0.0, 1.266667, 0.400000), + Vectors.dense(29.5, 17.0, -5.5)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(instances2) + val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} + +object IterativelyReweightedLeastSquaresSuite { + + def BinomialReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) + val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val w = mu * (1 - mu) * instance.weight + (z, w) + } + + def PoissonReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = math.exp(eta) + val z = eta + (instance.label - mu) / mu + val w = mu * instance.weight + (z, w) + } + + def L1RegressionReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val e = math.max(math.abs(eta - instance.label), 1e-7) + val w = 1 / e + val y = instance.label + (y, w) + } +} From abae889f08eb412cb897e4e63614ec2c93885ffd Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 28 Jan 2016 15:20:16 -0800 Subject: [PATCH 007/107] [SPARK-12401][SQL] Add integration tests for postgres enum types We can handle posgresql-specific enum types as strings in jdbc. So, we should just add tests and close the corresponding JIRA ticket. Author: Takeshi YAMAMURO Closes #10596 from maropu/AddTestsInIntegration. 
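Since the change is test-only, here is a short sketch of what "enum handled as string" looks like from the DataFrame side. The URL, credentials, and master setting are assumptions for illustration (any reachable Postgres instance prepared like the suite below works), and the Postgres JDBC driver is assumed to be on the classpath; the point is simply that the c13 enum column surfaces as an ordinary string.

```
import java.util.Properties

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object PostgresEnumReadSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("enum-as-string").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Hypothetical connection settings; point them at a database prepared like the suite below.
    val url = "jdbc:postgresql://localhost:5432/foo?user=postgres&password=rootpass"
    val df = sqlContext.read.jdbc(url, "bar", new Properties)

    df.printSchema()                   // c13 is reported as `string`, not as a dedicated enum type
    println(df.select("c13").first())  // [d1]

    sc.stop()
  }
}
```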
--- .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 7d011be37067b..72bda8fe1ef10 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.tags.DockerTest @DockerTest @@ -39,12 +39,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") + conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[], c12 real[])").executeUpdate() + + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate() } test("Type mapping for various types") { @@ -52,7 +53,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 13) + assert(types.length == 14) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(classOf[Seq[Double]].isAssignableFrom(types(12))) + assert(classOf[String].isAssignableFrom(types(13))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) assert(rows(0).getLong(3) == 123456789012345L) - assert(rows(0).getBoolean(4) == false) + assert(!rows(0).getBoolean(4)) // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... 
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) - assert(rows(0).getBoolean(7) == true) + assert(rows(0).getBoolean(7)) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) + assert(rows(0).getString(13) == "d1") } test("Basic write test") { From 3a40c0e575fd4215302ea60c9821d31a5a138b8a Mon Sep 17 00:00:00 2001 From: Brandon Bradley Date: Thu, 28 Jan 2016 15:25:57 -0800 Subject: [PATCH 008/107] [SPARK-12749][SQL] add json option to parse floating-point types as DecimalType I tried to add this via `USE_BIG_DECIMAL_FOR_FLOATS` option from Jackson with no success. Added test for non-complex types. Should I add a test for complex types? Author: Brandon Bradley Closes #10936 from blbradley/spark-12749. --- python/pyspark/sql/readwriter.py | 2 ++ .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../datasources/json/InferSchema.scala | 8 ++++-- .../datasources/json/JSONOptions.scala | 2 ++ .../datasources/json/JsonSuite.scala | 28 +++++++++++++++++++ 5 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 0b20022b14b8d..b1453c637f79e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -152,6 +152,8 @@ def json(self, path, schema=None): You can set the following JSON-specific options to deal with non-standard JSON files: * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ type + * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal \ + type * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 634c1bd4739b1..2e0c6c7df967e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -252,6 +252,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * You can set the following JSON-specific options to deal with non-standard JSON files: *
   * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+   * <li>`floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal
+   * type</li>
   * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
   * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 44d5e4ff7ec8b..8b773ddfcb656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -134,8 +134,12 @@ private[json] object InferSchema { val v = parser.getDecimalValue DecimalType(v.precision(), v.scale()) case FLOAT | DOUBLE => - // TODO(davies): Should we use decimal if possible? - DoubleType + if (configOptions.floatAsBigDecimal) { + val v = parser.getDecimalValue + DecimalType(v.precision(), v.scale()) + } else { + DoubleType + } } case VALUE_TRUE | VALUE_FALSE => BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index fe5b20697e40e..31a95ed461215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -34,6 +34,8 @@ private[sql] class JSONOptions( parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val floatAsBigDecimal = + parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false) val allowComments = parameters.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 00eaeb0d34e87..dd83a0e36f6f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -771,6 +771,34 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") { + val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DecimalType(17, -292), true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(BigDecimal("92233720368547758070"), + true, + BigDecimal("1.7976931348623157E308"), + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() From 4637fc08a3733ec313218fb7e4d05064d9a6262d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 16:25:21 -0800 Subject: [PATCH 009/107] [SPARK-11955][SQL] Mark optional fields in merging schema for safely pushdowning filters in Parquet JIRA: 
https://issues.apache.org/jira/browse/SPARK-11955 Currently we simply skip pushdowning filters in parquet if we enable schema merging. However, we can actually mark particular fields in merging schema for safely pushdowning filters in parquet. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #9940 from viirya/safe-pushdown-parquet-filters. --- .../org/apache/spark/sql/types/Metadata.scala | 5 +++ .../apache/spark/sql/types/StructType.scala | 34 ++++++++++++--- .../spark/sql/types/DataTypeSuite.scala | 14 ++++-- .../datasources/parquet/ParquetFilters.scala | 37 +++++++++++----- .../datasources/parquet/ParquetRelation.scala | 13 +++--- .../parquet/ParquetFilterSuite.scala | 43 ++++++++++++++++++- 6 files changed, 117 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 9e0f9943bc638..66f123682e117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -273,4 +273,9 @@ class MetadataBuilder { map.put(key, value) this } + + def remove(key: String): this.type = { + map.remove(key) + this + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bd733fa2d26c..da0c92864e9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -334,6 +334,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { + private[sql] val metadataKeyForOptionalField = "_OPTIONAL_" + override private[sql] def defaultConcreteType: DataType = new StructType override private[sql] def acceptsType(other: DataType): Boolean = { @@ -359,6 +361,18 @@ object StructType extends AbstractDataType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + def removeMetadata(key: String, dt: DataType): DataType = + dt match { + case StructType(fields) => + val newFields = fields.map { f => + val mb = new MetadataBuilder() + f.copy(dataType = removeMetadata(key, f.dataType), + metadata = mb.withMetadata(f.metadata).remove(key).build()) + } + StructType(newFields) + case _ => dt + } + private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), @@ -376,24 +390,32 @@ object StructType extends AbstractDataType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + // This metadata will record the fields that only exist in one of two StructTypes + val optionalMeta = new MetadataBuilder() val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse { + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + Some(leftField.copy(metadata = 
optionalMeta.build())) + } .foreach(newFields += _) } val leftMapped = fieldsMap(leftFields) rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach(newFields += _) + .foreach { f => + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + newFields += f.copy(metadata = optionalMeta.build()) + } StructType(newFields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 706ecd29d1355..c2bbca7c33f28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - assert(merged === left) + assert(DataType.equalsIgnoreCompatibleNullability(merged, left)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where left is empty") { @@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(right === merged) - + assert(DataType.equalsIgnoreCompatibleNullability(merged, right)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where both are non-empty") { @@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(merged === expected) + assert(DataType.equalsIgnoreCompatibleNullability(merged, expected)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where right contains type conflict") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e9b734b0abf50..5a5cb5cf03d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -207,11 +207,26 @@ private[sql] object ParquetFilters { */ } + /** + * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField. + * These fields only exist in one side of merged schemas. Due to that, we can't push down filters + * using such fields, otherwise Parquet library will throw exception. Here we filter out such + * fields. + */ + private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { + case StructType(fields) => + fields.filter { f => + !f.metadata.contains(StructType.metadataKeyForOptionalField) || + !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) + }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + case _ => Array.empty[(String, DataType)] + } + /** * Converts data sources filters to Parquet filter predicates. 
*/ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + val dataTypeOf = getFieldMap(schema).toMap relaxParquetValidTypeMap @@ -231,29 +246,29 @@ private[sql] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) => + case sources.IsNull(name) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) => + case sources.IsNotNull(name) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.EqualTo(name, value) => + case sources.EqualTo(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) => + case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) => + case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) => + case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThan(name, value) => + case sources.LessThan(name, value) if dataTypeOf.contains(name) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) => + case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThan(name, value) => + case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) => + case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.In(name, valueSet) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b460ec1d26047..f87590095d344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -258,7 +258,12 @@ private[sql] class ParquetRelation( job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) - CatalystWriteSupport.setSchema(dataSchema, conf) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). 
@@ -304,10 +309,6 @@ private[sql] class ParquetRelation( val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - // When merging schemas is enabled and the column of the given filter does not exist, - // Parquet emits an exception which is an issue of Parquet (PARQUET-389). - val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown - // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value // of these flags are smaller than the parquet row group size. @@ -321,7 +322,7 @@ private[sql] class ParquetRelation( dataSchema, parquetBlockSize, useMetadataCache, - safeParquetFilterPushDown, + parquetFilterPushDown, assumeBinaryIsString, assumeInt96IsTimestamp) _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 97c5313f0feff..1796b3af0e37a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -379,9 +380,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), - (1 to 1).map(i => Row(i, i.toString, null))) + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = sqlContext.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. 
+ val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = sqlContext.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) } } } From b9dfdcc63bb12bc24de96060e756889c2ceda519 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 17:01:12 -0800 Subject: [PATCH 010/107] Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage" This reverts commit cc18a7199240bf3b03410c1ba6704fe7ce6ae38e. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++------------ .../aggregate/TungstenAggregate.scala | 88 +++----- .../spark/sql/execution/basicOperators.scala | 96 ++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 202 insertions(+), 334 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f7..2747c315ad374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,23 +144,14 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * A prefix used to generate fresh name. - */ - var freshNamePrefix = "" - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" - } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" - } + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b8..d9fe76133c6ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049f..57f4945de9804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns the RDD of InternalRow which generates the input rows. + * Returns an input RDD of InternalRow and Java source code to process them. */ - def upstream(): RDD[InternalRow] - - /** - * Returns Java source code to process the rows from upstream. - */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { this.parent = parent - ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): String + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) /** - * Consume the columns generated from current SparkPlan, call it's parent. + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. 
*/ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { - assert(input.length == output.length) - } - parent.consumeChild(ctx, this, input, row) + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) } - /** - * Consume the columns generated from it's child, call doConsume() or emit the rows. - */ - def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - ctx.freshNamePrefix = nodeName - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, evals)} - """.stripMargin - } else { - doConsume(ctx, input) - } - } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String } @@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { - child.execute() - } + override def supportCodegen: Boolean = true - override def supportCodegen: Boolean = false - - override def upstream(): RDD[InternalRow] = { - child.execute() - } - - override def doProduce(ctx: CodegenContext): String = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - s""" - | while (input.hasNext()) { + val code = s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() ---------> upstream() -------> upstream() ------> execute() - * | - * -----------------> produce() + * doExecute() --------> produce() * | * doProduce() -------> produce() * | - * doProduce() + * doProduce() ---> execute() * | * consume() - * consumeChild() <-----------| + * doConsume() ------------| * | - * doConsume() - * | - * consumeChild() <----- consume() + * doConsume() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -206,48 +162,37 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { - override def supportCodegen: Boolean = false - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering - - override def doPrepare(): Unit = { - plan.prepare() - } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val (rdd, code) = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} + private Object[] references; + ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() throws java.io.IOException { + protected void processNext() { $code - } + } } - """ - + """ // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - plan.upstream().mapPartitions { iter => - + rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -258,47 +203,29 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def upstream(): RDD[InternalRow] = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { throw new UnsupportedOperationException } - override def doProduce(ctx: CodegenContext): String = { - throw new UnsupportedOperationException - } - - override def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - - if (row != null) { - // There is an UnsafeRow already + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | currentRow = $row; + | ${code.code.trim} + | currentRow = ${code.value}; | return; - """.stripMargin + """.stripMargin } else { - assert(input != null) - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns - s""" - | currentRow = unsafeRow; - | return; - """.stripMargin - } + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin } } @@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") 
- plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => + plan.children.exists(supportCodegen) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) + inputs += p + InputAdapter(p) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index cbd2634b8900f..23e54f344d252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,7 +117,9 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation only have one row, do not need to codegen + !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) } // The variables used as aggregation buffer @@ -125,11 +127,7 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -139,80 +137,50 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { - // evaluate aggregate results - ctx.currentVars = bufVars - val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, bufferAttrs).gen(ctx) - } - // evaluate result expressions - ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - 
BindReferences.bindReference(e, aggregateAttributes).gen(ctx) - } - (resultVars, s""" - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - """.stripMargin) - } else { - // output the aggregate buffer directly - (bufVars, "") - } - - val doAgg = ctx.freshName("doAgg") - ctx.addNewFunction(doAgg, + val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = s""" - | private void $doAgg() { + | if (!$initAgg) { + | $initAgg = true; + | | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $childSource + | + | // output the result + | ${consume(ctx, bufVars)} | } - """.stripMargin) + """.stripMargin - s""" - | if (!$initAgg) { - | $initAgg = true; - | $doAgg(); - | - | // output the result - | $genResult - | - | ${consume(ctx, resultVars)} - | } - """.stripMargin + (rdd, source) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => - e.mode match { - case Partial | Complete => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions - case PartialMerge | Final => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions - } + // the mode could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) } + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val updates = updateExpr.zipWithIndex.map { case (e, i) => - val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) + val codes = boundExpr.zipWithIndex.map { case (e, i) => + val ev = e.gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -222,7 +190,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${updates.mkString("")} + | ${codes.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4bf..6deb72adad5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val exprs = projectList.map(x 
=> ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -161,21 +153,17 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) - } - - protected override def doProduce(ctx: CodegenContext): String = { - val initTerm = ctx.freshName("initRange") + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("partitionEnd") + val partitionEnd = ctx.freshName("range_partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("number") + val number = ctx.freshName("range_number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("overflow") + val overflow = ctx.freshName("range_overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("value") + val value = ctx.freshName("range_value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -184,42 +172,38 @@ case class Range( s"$number > $partitionEnd" } - ctx.addNewFunction("initRange", - s""" - | private void initRange(int idx) { - | $BigInt index = $BigInt.valueOf(idx); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } - | } - """.stripMargin) + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) - s""" + val code = s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; 
| if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } | } else { | return; | } @@ -234,6 +218,12 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e4..989cb2942918e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,61 +1939,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) - - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) - } + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. 
+ val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c23..cbae19ebd269d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,24 +335,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) - } + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+ assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7d6bff8295d2b..d48143762cac0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index a3e5243b68aba..9a24a2487a254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() - } + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) assert(metrics(0) == 1) From 66449b8dcdbc3dca126c34b42c4d0419c7648696 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 22:20:52 -0800 Subject: [PATCH 011/107] [SPARK-12968][SQL] Implement command to set current database JIRA: https://issues.apache.org/jira/browse/SPARK-12968 Implement command to set current database. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10916 from viirya/ddl-use-database. --- .../spark/sql/catalyst/analysis/Catalog.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkQl.scala | 3 +++ .../apache/spark/sql/execution/commands.scala | 10 ++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 -- .../spark/sql/hive/client/ClientInterface.scala | 3 +++ .../spark/sql/hive/client/ClientWrapper.scala | 9 +++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 16 ++++++++++++++++ 9 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index a8f89ce6de457..f2f9ec59417ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -46,6 +46,10 @@ trait Catalog { def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan + def setCurrentDatabase(databaseName: String): Unit = { + throw new UnsupportedOperationException + } + /** * Returns tuples of (tableName, isTemporary) for all tables in the given database. * isTemporary is a Boolean value indicates if a table is a temporary or not. 
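A rough end-to-end sketch of what this patch enables (it mirrors the `use database` test added to HiveQuerySuite further down in this patch, assumes a HiveContext, and `hive_test_db` is just the test database name used there):

  sqlContext.sql("CREATE DATABASE hive_test_db")
  sqlContext.sql("USE hive_test_db")   // parsed as TOK_SWITCHDATABASE, executed as SetDatabaseCommand
  assert(sqlContext.sql("SELECT current_database()").first().getString(0) == "hive_test_db")
  // USE on a database that does not exist is expected to fail with NoSuchDatabaseException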
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index f6055306b6c97..a5bd8ee42dec9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -55,6 +55,9 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => + SetDatabaseCommand(cleanIdentifier(database)) + case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val Some(tableType) :: formatted :: extended :: pretty :: Nil = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 3cfa3dfd9c7ec..703e4643cbd25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -408,3 +408,13 @@ case class DescribeFunction( } } } + +case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.catalog.setCurrentDatabase(databaseName) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index ab31d45a79a2e..72da266da4d01 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" - -> "OK", + -> "", "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a9c0e9ab7caef..848aa4ec6fe56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -711,6 +711,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } override def unregisterAllTables(): Unit = {} + + override def setCurrentDatabase(databaseName: String): Unit = { + client.setCurrentDatabase(databaseName) + } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 22841ed2116d1..752c037a842a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -155,8 +155,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_SHOWLOCKS", "TOK_SHOWPARTITIONS", - "TOK_SWITCHDATABASE", - "TOK_UNLOCKTABLE" ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 9d9a55edd7314..4eec3fef7408b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -109,6 +109,9 @@ private[hive] trait ClientInterface { /** Returns the name of the active database. */ def currentDatabase: String + /** Sets the name of current database. */ + def setCurrentDatabase(databaseName: String): Unit + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ def getDatabase(name: String): HiveDatabase = { getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index ce7a305d437a5..5307e924e7e55 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -229,6 +230,14 @@ private[hive] class ClientWrapper( state.getCurrentDatabase } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + if (getDatabaseOption(databaseName).isDefined) { + state.setCurrentDatabase(databaseName) + } else { + throw new NoSuchDatabaseException + } + } + override def createDatabase(database: HiveDatabase): Unit = withHiveState { client.createDatabase( new Database( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4659d745fe78b..9632d27a2ffce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin @@ -1262,6 +1263,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("use database") { + val currentDatabase = sql("select current_database()").first().getString(0) + + sql("CREATE DATABASE hive_test_db") + sql("USE hive_test_db") + assert("hive_test_db" == sql("select current_database()").first().getString(0)) + + intercept[NoSuchDatabaseException] { + sql("USE not_existing_db") + } + + sql(s"USE $currentDatabase") + assert(currentDatabase == sql("select current_database()").first().getString(0)) + } + test("lookup hive UDF in another thread") { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") From 721ced28b522cc00b45ca7fa32a99e80ad3de2f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: 
Thu, 28 Jan 2016 22:42:43 -0800 Subject: [PATCH 012/107] [SPARK-13067] [SQL] workaround for a weird scala reflection problem A simple workaround to avoid getting parameter types when convert a logical plan to json. Author: Wenchen Fan Closes #10970 from cloud-fan/reflection. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 ++++++++++++++++--- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 643228d0eb27d..e5811efb436a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -601,6 +601,20 @@ object ScalaReflection extends ScalaReflection { getConstructorParameters(t) } + /** + * Returns the parameter names for the primary constructor of this class. + * + * Logically we should call `getConstructorParameters` and throw away the parameter types to get + * parameter names, however there are some weird scala reflection problems and this method is a + * workaround to avoid getting parameter types. + */ + def getConstructorParameterNames(cls: Class[_]): Seq[String] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + constructParams(t).map(_.name.toString) + } + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } @@ -745,6 +759,12 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe + constructParams(tpe).map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } + + protected def constructParams(tpe: Type): Seq[Symbol] = { val constructorSymbol = tpe.member(nme.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramss @@ -758,9 +778,6 @@ trait ScalaReflection { primaryConstructorSymbol.get.asMethod.paramss } } - - params.flatten.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - } + params.flatten } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 57e1a3c9eb226..2df0683f9fa16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } protected def jsonFields: List[JField] = { - val fieldNames = getConstructorParameters(getClass).map(_._1) + val fieldNames = getConstructorParameterNames(getClass) val fieldValues = productIterator.toSeq ++ otherCopyArgs assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) @@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName // returns null if the product type doesn't have a primary constructor, e.g. 
HiveFunctionWrapper case p: Product => try { - val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldNames = getConstructorParameterNames(p.getClass) val fieldValues = p.productIterator.toSeq assert(fieldNames.length == fieldValues.length) ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { From 8d3cc3de7d116190911e7943ef3233fe3b7db1bf Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 28 Jan 2016 23:34:50 -0800 Subject: [PATCH 013/107] [SPARK-13050][BUILD] Scalatest tags fail build with the addition of the sketch module A dependency on the spark test tags was left out of the sketch module pom file causing builds to fail when test tags were used. This dependency is found in the pom file for every other module in spark. Author: Alex Bozarth Closes #10954 from ajbozarth/spark13050. --- common/sketch/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 67723fa421ab1..2cafe8c548f5f 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -35,6 +35,13 @@ sketch + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes From 55561e7693dd2a5bf3c7f8026c725421801fd0ec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 01:59:59 -0800 Subject: [PATCH 014/107] [SPARK-13031][SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. This PR re-open #10944 and fix the bug. Author: Davies Liu Closes #10977 from davies/gen_refactor. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/AggregationIterator.scala | 2 +- .../aggregate/TungstenAggregate.scala | 98 ++++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../spark/sql/DataFrameAggregateSuite.scala | 7 + .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 11 files changed, 350 insertions(+), 205 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad374..e6704cf8bb1f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
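 * (Illustrative note: with the prefix support added just below, once produce()/consumeChild() set
 * freshNamePrefix to the operator's nodeName -- e.g. "Filter" -- freshName("value") returns names
 * like "Filter_value4" instead of "value4", which makes the generated code attributable to an operator.)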
*/ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6ef..ec31db19b94b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de9804..ef81ba60f049f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. 
*/ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
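 * As a concrete, purely illustrative example: for a query along the lines of
 *   sqlContext.range(0, 1000).filter("id % 100 = 0").selectExpr("id + 1 as x")
 * Range, Filter and Project all implement CodegenSupport, so CollapseCodegenStages folds them into a
 * single WholeStageCodegen node. Its doExecute() obtains the input RDD through the chain of
 * upstream() calls ending at Range, and the generated processNext() then drives the
 * produce()/consume() chain shown above, so rows flow through generated code rather than
 * through per-operator iterators.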
* @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - 
plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 0c74df0aa5fdd..38da82c47ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -238,7 +238,7 @@ abstract class AggregationIterator( resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { - // Grouping-only: we only output values of grouping expressions. + // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { resultProjection(currentGroupingKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d252..ff2f38bfd9105 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,60 +139,96 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + 
ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.gen(ctx)) + (resultVars, resultVars.map(_.code).mkString("\n")) + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => s""" - | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; | ${bufVars(i).value} = ${ev.value}; 
""".stripMargin } s""" - | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | // do aggregate + | ${aggVals.map(_.code).mkString("\n")} + | // update aggregation buffer + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5ec..e7a73d5fbb4bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case 
class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b1004bc5bc290..08fb7c9d84c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("agg without groups and functions") { + checkAnswer( + testData2.agg(lit(1)), + Row(1) + ) + } + test("average") { checkAnswer( testData2.agg(avg('a), mean('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918e..51a50c1fa30e4 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269d..82f6811503c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. 
- assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac0..7d6bff8295d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a254..a3e5243b68aba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From e51b6eaa9e9c007e194d858195291b2b9fb27322 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 29 Jan 2016 09:22:24 -0800 Subject: [PATCH 015/107] [SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example * Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark. * Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community. cc mengxr jkbradley Author: Yanbo Liang Author: Joseph K. Bradley Closes #10469 from yanboliang/spark-11939. 
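The Python wrappers added below delegate to the JVM writers and readers through the active
SparkContext; as a minimal sketch (paths are placeholders), the Scala-side API they ultimately
call into looks like:

  import org.apache.spark.ml.regression.LinearRegression
  val lr = new LinearRegression().setMaxIter(5)
  lr.save("/tmp/lr")                             // MLWritable.save, reached via JavaMLWriter
  val loaded = LinearRegression.load("/tmp/lr")  // MLReadable.load
  assert(loaded.getMaxIter == 5)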
--- python/pyspark/ml/param/__init__.py | 24 +++++ python/pyspark/ml/regression.py | 30 +++++- python/pyspark/ml/tests.py | 36 +++++-- python/pyspark/ml/util.py | 142 +++++++++++++++++++++++++++- python/pyspark/ml/wrapper.py | 33 ++++--- 5 files changed, 236 insertions(+), 29 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 3da36d32c5af0..ea86d6aeb8b31 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None): if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]}) return to + + def _resetUid(self, newUid): + """ + Changes the uid of this instance. This updates both + the stored uid and the parent uid of params and param maps. + This is used by persistence (loading). + :param newUid: new uid to use + :return: same instance, but with the uid and Param.parent values + updated, including within param maps + """ + self.uid = newUid + newDefaultParamMap = dict() + newParamMap = dict() + for param in self.params: + newParam = copy.copy(param) + newParam.parent = newUid + if param in self._defaultParamMap: + newDefaultParamMap[newParam] = self._defaultParamMap[param] + if param in self._paramMap: + newParamMap[newParam] = self._paramMap[param] + param.parent = newUid + self._defaultParamMap = newDefaultParamMap + self._paramMap = newParamMap + return self diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 74a2248ed07c8..20dc6c2db91f3 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -18,9 +18,9 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.mllib.common import inherit_doc @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol): + HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): """ Linear regression. @@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lr_path = path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LinearRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -106,7 +125,7 @@ def _create_model(self, java_model): return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel): +class LinearRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by LinearRegression. 
@@ -821,9 +840,10 @@ def predict(self, features): if __name__ == "__main__": import doctest + import pyspark.ml.regression from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.regression tests") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c45a159c460f3..54806ee336666 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -34,18 +34,22 @@ else: import unittest -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext, Row -from pyspark.sql.functions import rand +from shutil import rmtree +import tempfile + +from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.feature import * from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.util import keyword_only -from pyspark.ml import Estimator, Model, Pipeline, Transformer -from pyspark.ml.feature import * +from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel +from pyspark.ml.util import keyword_only from pyspark.mllib.linalg import DenseVector +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase class MockDataset(DataFrame): @@ -405,6 +409,26 @@ def test_fit_maximize_metric(self): self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") +class PersistenceTest(PySparkTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index cee9d67b05325..d7a813f56cd57 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,8 +15,27 @@ # limitations under the License. # -from functools import wraps +import sys import uuid +from functools import wraps + +if sys.version > '3': + basestring = str + +from pyspark import SparkContext, since +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") def keyword_only(func): @@ -52,3 +71,124 @@ def _randomUID(cls): concatenates the class name, "_", and 12 random hex chars. """ return cls.__name__ + "_" + uuid.uuid4().hex[12:] + + +@inherit_doc +class JavaMLWriter(object): + """ + .. 
note:: Experimental + + Utility class that can save ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, instance): + instance._transfer_params_to_java() + self._jwrite = instance._java_obj.write() + + def save(self, path): + """Save the ML instance to the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._jwrite.save(path) + + def overwrite(self): + """Overwrites if the output path already exists.""" + self._jwrite.overwrite() + return self + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + self._jwrite.context(sqlContext._ssql_ctx) + return self + + +@inherit_doc +class MLWritable(object): + """ + .. note:: Experimental + + Mixin for ML instances that provide JavaMLWriter. + + .. versionadded:: 2.0.0 + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + +@inherit_doc +class JavaMLReader(object): + """ + .. note:: Experimental + + Utility class that can load ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, clazz): + self._clazz = clazz + self._jread = self._load_java_obj(clazz).read() + + def load(self, path): + """Load the ML instance from the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_obj = self._jread.load(path) + instance = self._clazz() + instance._java_obj = java_obj + instance._resetUid(java_obj.uid()) + instance._transfer_params_from_java() + return instance + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + self._jread.context(sqlContext._ssql_ctx) + return self + + @classmethod + def _java_loader_class(cls, clazz): + """ + Returns the full class name of the Java ML instance. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = clazz.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, clazz.__name__]) + + @classmethod + def _load_java_obj(cls, clazz): + """Load the peer Java object of the ML instance.""" + java_class = cls._java_loader_class(clazz) + java_obj = _jvm() + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj + + +@inherit_doc +class MLReadable(object): + """ + .. note:: Experimental + + Mixin for instances that provide JavaMLReader. + + .. versionadded:: 2.0.0 + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def load(cls, path): + """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" + return cls.read().load(path) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index dd1d4b076eddd..d4d48eb2150e3 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,21 +21,10 @@ from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -def _jvm(): - """ - Returns the JVM view associated with SparkContext. 
Must be called
-    after SparkContext is initialized.
-    """
-    jvm = SparkContext._jvm
-    if jvm:
-        return jvm
-    else:
-        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
 @inherit_doc
 class JavaWrapper(Params):
     """
@@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer):

     __metaclass__ = ABCMeta

-    def __init__(self, java_model):
+    def __init__(self, java_model=None):
         """
         Initialize this instance with a Java model object.
         Subclasses should call this constructor, initialize params,
         and then call _transformer_params_from_java.
+
+        This instance can be instantiated without specifying java_model,
+        it will be assigned after that, but this scenario only used by
+        :py:class:`JavaMLReader` to load models.  This is a bit of a
+        hack, but it is easiest since a proper fix would require
+        MLReader (in pyspark.ml.util) to depend on these wrappers, but
+        these wrappers depend on pyspark.ml.util (both directly and via
+        other ML classes).
         """
         super(JavaModel, self).__init__()
-        self._java_obj = java_model
-        self.uid = java_model.uid()
+        if java_model is not None:
+            self._java_obj = java_model
+            self.uid = java_model.uid()

     def copy(self, extra=None):
         """
@@ -182,8 +180,9 @@ def copy(self, extra=None):
         if extra is None:
             extra = dict()
         that = super(JavaModel, self).copy(extra)
-        that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-        that._transfer_params_to_java()
+        if self._java_obj is not None:
+            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+            that._transfer_params_to_java()
         return that

     def _call_java(self, name, *args):

From e4c1162b6b3dbc8fc95cfe75c6e0bc2915575fb2 Mon Sep 17 00:00:00 2001
From: zhuol
Date: Fri, 29 Jan 2016 11:54:58 -0600
Subject: [PATCH 016/107] [SPARK-10873] Support column sort and search for History Server.

[SPARK-10873] Support column sort and search for History Server using jQuery DataTable and REST API.

Before this commit, the history server generated hard-coded HTML and did not support search; sorting was also disabled if any application had more than one attempt. Supporting search and sort (over all applications rather than only the 20 entries on the current page) greatly improves the user experience.

1. Create historypage-template.html for displaying application information in DataTables.
2. historypage.js uses jQuery to access the data from the /api/v1/applications REST API and uses DataTables to display each application's information. For applications that have more than one attempt, the RowsGroup plugin is used to merge such entries while still supporting sort and search.
3. "duration" and "lastUpdated" REST API fields are added to each application's "attempts".
4. External javascript and css files for DataTables, RowsGroup and jQuery plugins are added, with licenses clarified.

Snapshots of how it looks now:

History page view:
![historypage](https://cloud.githubusercontent.com/assets/11683054/12184383/89bad774-b55a-11e5-84e4-b0276172976f.png)

Search:
![search](https://cloud.githubusercontent.com/assets/11683054/12184385/8d3b94b0-b55a-11e5-869a-cc0ef0a4242a.png)

Sort by started time:
![sort-by-started-time](https://cloud.githubusercontent.com/assets/11683054/12184387/8f757c3c-b55a-11e5-98c8-577936366566.png)

Author: zhuol

Closes #10648 from zhuoliu/10873.
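To illustrate point 3 above (this snippet is not part of the patch), the new per-attempt fields can be read from the same /api/v1/applications endpoint that historypage.js consumes. The host/port and the use of the third-party `requests` package are assumptions; adjust them for your deployment.

```python
# Minimal sketch: list applications from a history server's REST API and
# print the per-attempt "duration" and "lastUpdated" fields added here.
# Assumes a history server reachable at localhost:18080.
import requests

base_url = "http://localhost:18080"  # assumed history server address
apps = requests.get(base_url + "/api/v1/applications").json()

for app in apps:
    for attempt in app["attempts"]:
        # .get() is used defensively in case a field is absent on older servers
        print(app["id"], app["name"],
              attempt.get("duration"), attempt.get("lastUpdated"))
```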
--- .rat-excludes | 10 + LICENSE | 6 + .../spark/ui/static/dataTables.bootstrap.css | 319 ++++++++++ .../ui/static/dataTables.bootstrap.min.js | 8 + .../spark/ui/static/dataTables.rowsGroup.js | 224 +++++++ .../spark/ui/static/historypage-template.html | 81 +++ .../org/apache/spark/ui/static/historypage.js | 159 +++++ .../spark/ui/static/jquery.blockUI.min.js | 6 + .../ui/static/jquery.cookies.2.2.0.min.js | 18 + .../static/jquery.dataTables.1.10.4.min.css | 1 + .../ui/static/jquery.dataTables.1.10.4.min.js | 157 +++++ .../apache/spark/ui/static/jquery.mustache.js | 592 ++++++++++++++++++ .../spark/ui/static/jsonFormatter.min.css | 1 + .../spark/ui/static/jsonFormatter.min.js | 2 + .../spark/deploy/history/HistoryPage.scala | 193 +----- .../api/v1/ApplicationListResource.scala | 14 + .../org/apache/spark/status/api/v1/api.scala | 2 + .../scala/org/apache/spark/ui/SparkUI.scala | 2 + .../scala/org/apache/spark/ui/UIUtils.scala | 11 + .../application_list_json_expectation.json | 18 +- .../completed_app_list_json_expectation.json | 18 +- .../maxDate2_app_list_json_expectation.json | 4 +- .../maxDate_app_list_json_expectation.json | 6 +- .../minDate_app_list_json_expectation.json | 14 +- .../one_app_json_expectation.json | 4 +- ...ne_app_multi_attempt_json_expectation.json | 6 +- .../deploy/history/HistoryServerSuite.scala | 43 +- project/MimaExcludes.scala | 4 + 28 files changed, 1721 insertions(+), 202 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-template.html create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.css create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js diff --git a/.rat-excludes b/.rat-excludes index a4f316a4aaa04..874a6ee9f4043 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,16 @@ graphlib-dot.min.js sorttable.js vis.min.js vis.min.css +dataTables.bootstrap.css +dataTables.bootstrap.min.js +dataTables.rowsGroup.js +jquery.blockUI.min.js +jquery.cookies.2.2.0.min.js +jquery.dataTables.1.10.4.min.css +jquery.dataTables.1.10.4.min.js +jquery.mustache.js +jsonFormatter.min.css +jsonFormatter.min.js .*avsc .*txt .*json diff --git a/LICENSE b/LICENSE index 9c944ac610afe..9fc29db8d3f22 100644 --- a/LICENSE +++ b/LICENSE @@ -291,3 +291,9 @@ The text of each license is also included at licenses/LICENSE-[project].txt. 
(MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) + (MIT License) datatables (http://datatables.net/license) + (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) + (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) + (MIT License) blockUI (http://jquery.malsup.com/block/) + (MIT License) RowsGroup (http://datatables.net/license/mit) + (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css new file mode 100644 index 0000000000000..faee0e50dbfea --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css @@ -0,0 +1,319 @@ +div.dataTables_length label { + font-weight: normal; + text-align: left; + white-space: nowrap; +} + +div.dataTables_length select { + width: 75px; + display: inline-block; +} + +div.dataTables_filter { + text-align: right; +} + +div.dataTables_filter label { + font-weight: normal; + white-space: nowrap; + text-align: left; +} + +div.dataTables_filter input { + margin-left: 0.5em; + display: inline-block; +} + +div.dataTables_info { + padding-top: 8px; + white-space: nowrap; +} + +div.dataTables_paginate { + margin: 0; + white-space: nowrap; + text-align: right; +} + +div.dataTables_paginate ul.pagination { + margin: 2px 0; + white-space: nowrap; +} + +@media screen and (max-width: 767px) { + div.dataTables_length, + div.dataTables_filter, + div.dataTables_info, + div.dataTables_paginate { + text-align: center; + } +} + + +table.dataTable td, +table.dataTable th { + -webkit-box-sizing: content-box; + -moz-box-sizing: content-box; + box-sizing: content-box; +} + + +table.dataTable { + clear: both; + margin-top: 6px !important; + margin-bottom: 6px !important; + max-width: none !important; +} + +table.dataTable thead .sorting, +table.dataTable thead .sorting_asc, +table.dataTable thead .sorting_desc, +table.dataTable thead .sorting_asc_disabled, +table.dataTable thead .sorting_desc_disabled { + cursor: pointer; +} + +table.dataTable thead .sorting { background: url('../images/sort_both.png') no-repeat center right; } +table.dataTable thead .sorting_asc { background: url('../images/sort_asc.png') no-repeat center right; } +table.dataTable thead .sorting_desc { background: url('../images/sort_desc.png') no-repeat center right; } + +table.dataTable thead .sorting_asc_disabled { background: url('../images/sort_asc_disabled.png') no-repeat center right; } +table.dataTable thead .sorting_desc_disabled { background: url('../images/sort_desc_disabled.png') no-repeat center right; } + +table.dataTable thead > tr > th { + padding-left: 18px; + padding-right: 18px; +} + +table.dataTable th:active { + outline: none; +} + +/* Scrolling */ +div.dataTables_scrollHead table { + margin-bottom: 0 !important; + border-bottom-left-radius: 0; + border-bottom-right-radius: 0; +} + +div.dataTables_scrollHead table thead tr:last-child th:first-child, +div.dataTables_scrollHead table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.dataTables_scrollBody table { + border-top: none; + margin-top: 0 !important; + margin-bottom: 0 !important; +} + 
+div.dataTables_scrollBody tbody tr:first-child th, +div.dataTables_scrollBody tbody tr:first-child td { + border-top: none; +} + +div.dataTables_scrollFoot table { + margin-top: 0 !important; + border-top: none; +} + +/* Frustratingly the border-collapse:collapse used by Bootstrap makes the column + width calculations when using scrolling impossible to align columns. We have + to use separate + */ +table.table-bordered.dataTable { + border-collapse: separate !important; +} +table.table-bordered thead th, +table.table-bordered thead td { + border-left-width: 0; + border-top-width: 0; +} +table.table-bordered tbody th, +table.table-bordered tbody td { + border-left-width: 0; + border-bottom-width: 0; +} +table.table-bordered th:last-child, +table.table-bordered td:last-child { + border-right-width: 0; +} +div.dataTables_scrollHead table.table-bordered { + border-bottom-width: 0; +} + + + + +/* + * TableTools styles + */ +.table.dataTable tbody tr.active td, +.table.dataTable tbody tr.active th { + background-color: #08C; + color: white; +} + +.table.dataTable tbody tr.active:hover td, +.table.dataTable tbody tr.active:hover th { + background-color: #0075b0 !important; +} + +.table.dataTable tbody tr.active th > a, +.table.dataTable tbody tr.active td > a { + color: white; +} + +.table-striped.dataTable tbody tr.active:nth-child(odd) td, +.table-striped.dataTable tbody tr.active:nth-child(odd) th { + background-color: #017ebc; +} + +table.DTTT_selectable tbody tr { + cursor: pointer; +} + +div.DTTT .btn { + color: #333 !important; + font-size: 12px; +} + +div.DTTT .btn:hover { + text-decoration: none !important; +} + +ul.DTTT_dropdown.dropdown-menu { + z-index: 2003; +} + +ul.DTTT_dropdown.dropdown-menu a { + color: #333 !important; /* needed only when demo_page.css is included */ +} + +ul.DTTT_dropdown.dropdown-menu li { + position: relative; +} + +ul.DTTT_dropdown.dropdown-menu li:hover a { + background-color: #0088cc; + color: white !important; +} + +div.DTTT_collection_background { + z-index: 2002; +} + +/* TableTools information display */ +div.DTTT_print_info { + position: fixed; + top: 50%; + left: 50%; + width: 400px; + height: 150px; + margin-left: -200px; + margin-top: -75px; + text-align: center; + color: #333; + padding: 10px 30px; + opacity: 0.95; + + background-color: white; + border: 1px solid rgba(0, 0, 0, 0.2); + border-radius: 6px; + + -webkit-box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); + box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); +} + +div.DTTT_print_info h6 { + font-weight: normal; + font-size: 28px; + line-height: 28px; + margin: 1em; +} + +div.DTTT_print_info p { + font-size: 14px; + line-height: 20px; +} + +div.dataTables_processing { + position: absolute; + top: 50%; + left: 50%; + width: 100%; + height: 60px; + margin-left: -50%; + margin-top: -25px; + padding-top: 20px; + padding-bottom: 20px; + text-align: center; + font-size: 1.2em; + background-color: white; + background: -webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0))); + background: -webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, 
rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); +} + + + +/* + * FixedColumns styles + */ +div.DTFC_LeftHeadWrapper table, +div.DTFC_LeftFootWrapper table, +div.DTFC_RightHeadWrapper table, +div.DTFC_RightFootWrapper table, +table.DTFC_Cloned tr.even { + background-color: white; + margin-bottom: 0; +} + +div.DTFC_RightHeadWrapper table , +div.DTFC_LeftHeadWrapper table { + border-bottom: none !important; + margin-bottom: 0 !important; + border-top-right-radius: 0 !important; + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_RightHeadWrapper table thead tr:last-child td:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightBodyWrapper table, +div.DTFC_LeftBodyWrapper table { + border-top: none; + margin: 0 !important; +} + +div.DTFC_RightBodyWrapper tbody tr:first-child th, +div.DTFC_RightBodyWrapper tbody tr:first-child td, +div.DTFC_LeftBodyWrapper tbody tr:first-child th, +div.DTFC_LeftBodyWrapper tbody tr:first-child td { + border-top: none; +} + +div.DTFC_RightFootWrapper table, +div.DTFC_LeftFootWrapper table { + border-top: none; + margin-top: 0 !important; +} + + +/* + * FixedHeader styles + */ +div.FixedHeader_Cloned table { + margin: 0 !important +} + diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js new file mode 100644 index 0000000000000..f0d09b9d52668 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js @@ -0,0 +1,8 @@ +/*! + DataTables Bootstrap 3 integration + ©2011-2014 SpryMedia Ltd - datatables.net/license +*/ +(function(){var f=function(c,b){c.extend(!0,b.defaults,{dom:"<'row'<'col-sm-6'l><'col-sm-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-6'i><'col-sm-6'p>>",renderer:"bootstrap"});c.extend(b.ext.classes,{sWrapper:"dataTables_wrapper form-inline dt-bootstrap",sFilterInput:"form-control input-sm",sLengthSelect:"form-control input-sm"});b.ext.renderer.pageButton.bootstrap=function(g,f,p,k,h,l){var q=new b.Api(g),r=g.oClasses,i=g.oLanguage.oPaginate,d,e,o=function(b,f){var j,m,n,a,k=function(a){a.preventDefault(); +c(a.currentTarget).hasClass("disabled")||q.page(a.data.action).draw(!1)};j=0;for(m=f.length;j",{"class":r.sPageButton+" "+ +e,"aria-controls":g.sTableId,tabindex:g.iTabIndex,id:0===p&&"string"===typeof a?g.sTableId+"_"+a:null}).append(c("",{href:"#"}).html(d)).appendTo(b),g.oApi._fnBindAction(n,{action:a},k))}};o(c(f).empty().html('
      ').children("ul"),k)};b.TableTools&&(c.extend(!0,b.TableTools.classes,{container:"DTTT btn-group",buttons:{normal:"btn btn-default",disabled:"disabled"},collection:{container:"DTTT_dropdown dropdown-menu",buttons:{normal:"",disabled:"disabled"}},print:{info:"DTTT_print_info"}, +select:{row:"active"}}),c.extend(!0,b.TableTools.DEFAULTS.oTags,{collection:{container:"ul",button:"li",liner:"a"}}))};"function"===typeof define&&define.amd?define(["jquery","datatables"],f):"object"===typeof exports?f(require("jquery"),require("datatables")):jQuery&&f(jQuery,jQuery.fn.dataTable)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js new file mode 100644 index 0000000000000..983c3a564fb1b --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js @@ -0,0 +1,224 @@ +/*! RowsGroup for DataTables v1.0.0 + * 2015 Alexey Shildyakov ashl1future@gmail.com + */ + +/** + * @summary RowsGroup + * @description Group rows by specified columns + * @version 1.0.0 + * @file dataTables.rowsGroup.js + * @author Alexey Shildyakov (ashl1future@gmail.com) + * @contact ashl1future@gmail.com + * @copyright Alexey Shildyakov + * + * License MIT - http://datatables.net/license/mit + * + * This feature plug-in for DataTables automatically merges columns cells + * based on it's values equality. It supports multi-column row grouping + * in according to the requested order with dependency from each previous + * requested columns. Now it supports ordering and searching. + * Please see the example.html for details. + * + * Rows grouping in DataTables can be enabled by using any one of the following + * options: + * + * * Setting the `rowsGroup` parameter in the DataTables initialisation + * to array which contains columns selectors + * (https://datatables.net/reference/type/column-selector) used for grouping. i.e. + * rowsGroup = [1, 'columnName:name', ] + * * Setting the `rowsGroup` parameter in the DataTables defaults + * (thus causing all tables to have this feature) - i.e. + * `$.fn.dataTable.defaults.RowsGroup = [0]`. + * * Creating a new instance: `new $.fn.dataTable.RowsGroup( table, columnsForGrouping );` + * where `table` is a DataTable's API instance and `columnsForGrouping` is the array + * described above. 
+ * + * For more detailed information please see: + * + */ + +(function($){ + +ShowedDataSelectorModifier = { + order: 'current', + page: 'current', + search: 'applied', +} + +GroupedColumnsOrderDir = 'desc'; // change + + +/* + * columnsForGrouping: array of DTAPI:cell-selector for columns for which rows grouping is applied + */ +var RowsGroup = function ( dt, columnsForGrouping ) +{ + this.table = dt.table(); + this.columnsForGrouping = columnsForGrouping; + // set to True when new reorder is applied by RowsGroup to prevent order() looping + this.orderOverrideNow = false; + this.order = [] + + self = this; + $(document).on('order.dt', function ( e, settings) { + if (!self.orderOverrideNow) { + self._updateOrderAndDraw() + } + self.orderOverrideNow = false; + }) + + $(document).on('draw.dt', function ( e, settings) { + self._mergeCells() + }) + + this._updateOrderAndDraw(); +}; + + +RowsGroup.prototype = { + _getOrderWithGroupColumns: function (order, groupedColumnsOrderDir) + { + if (groupedColumnsOrderDir === undefined) + groupedColumnsOrderDir = GroupedColumnsOrderDir + + var self = this; + var groupedColumnsIndexes = this.columnsForGrouping.map(function(columnSelector){ + return self.table.column(columnSelector).index() + }) + var groupedColumnsKnownOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) >= 0 + }) + var nongroupedColumnsOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) < 0 + }) + var groupedColumnsKnownOrderIndexes = groupedColumnsKnownOrder.map(function(columnOrder){ + return columnOrder[0] + }) + var groupedColumnsOrder = groupedColumnsIndexes.map(function(iColumn){ + var iInOrderIndexes = groupedColumnsKnownOrderIndexes.indexOf(iColumn) + if (iInOrderIndexes >= 0) + return [iColumn, groupedColumnsKnownOrder[iInOrderIndexes][1]] + else + return [iColumn, groupedColumnsOrderDir] + }) + + groupedColumnsOrder.push.apply(groupedColumnsOrder, nongroupedColumnsOrder) + return groupedColumnsOrder; + }, + + // Workaround: the DT reset ordering to 'desc' from multi-ordering if user order on one column (without shift) + // but because we always has multi-ordering due to grouped rows this happens every time + _getInjectedMonoSelectWorkaround: function(order) + { + if (order.length === 1) { + // got mono order - workaround here + var orderingColumn = order[0][0] + var previousOrder = this.order.map(function(val){ + return val[0] + }) + var iColumn = previousOrder.indexOf(orderingColumn); + if (iColumn >= 0) { + // assume change the direction, because we already has that in previous order + return [[orderingColumn, this._toogleDirection(this.order[iColumn][1])]] + } // else This is the new ordering column. Proceed as is. 
+ } // else got multi order - work normal + return order; + }, + + _mergeCells: function() + { + var columnsIndexes = this.table.columns(this.columnsForGrouping, ShowedDataSelectorModifier).indexes().toArray() + var showedRowsCount = this.table.rows(ShowedDataSelectorModifier)[0].length + this._mergeColumn(0, showedRowsCount - 1, columnsIndexes) + }, + + // the index is relative to the showed data + // (selector-modifier = {order: 'current', page: 'current', search: 'applied'}) index + _mergeColumn: function(iStartRow, iFinishRow, columnsIndexes) + { + var columnsIndexesCopy = columnsIndexes.slice() + currentColumn = columnsIndexesCopy.shift() + currentColumn = this.table.column(currentColumn, ShowedDataSelectorModifier) + + var columnNodes = currentColumn.nodes() + var columnValues = currentColumn.data() + + var newSequenceRow = iStartRow, + iRow; + for (iRow = iStartRow + 1; iRow <= iFinishRow; ++iRow) { + + if (columnValues[iRow] === columnValues[newSequenceRow]) { + $(columnNodes[iRow]).hide() + } else { + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1) - newSequenceRow + 1) + + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + + newSequenceRow = iRow; + } + + } + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1)- newSequenceRow + 1) + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + }, + + _toogleDirection: function(dir) + { + return dir == 'asc'? 'desc': 'asc'; + }, + + _updateOrderAndDraw: function() + { + this.orderOverrideNow = true; + + var currentOrder = this.table.order(); + currentOrder = this._getInjectedMonoSelectWorkaround(currentOrder); + this.order = this._getOrderWithGroupColumns(currentOrder) + // this.table.order($.extend(true, Array(), this.order)) // disable this line in order to support sorting on non-grouped columns + this.table.draw(false) + }, +}; + + +$.fn.dataTable.RowsGroup = RowsGroup; +$.fn.DataTable.RowsGroup = RowsGroup; + +// Automatic initialisation listener +$(document).on( 'init.dt', function ( e, settings ) { + if ( e.namespace !== 'dt' ) { + return; + } + + var api = new $.fn.dataTable.Api( settings ); + + if ( settings.oInit.rowsGroup || + $.fn.dataTable.defaults.rowsGroup ) + { + options = settings.oInit.rowsGroup? + settings.oInit.rowsGroup: + $.fn.dataTable.defaults.rowsGroup; + new RowsGroup( api, options ); + } +} ); + +}(jQuery)); + +/* + +TODO: Provide function which determines the all s and s with "rowspan" html-attribute is parent (groupped) for the specified or . To use in selections, editing or hover styles. + +TODO: Feature +Use saved order direction for grouped columns + Split the columns into grouped and ungrouped. + + user = grouped+ungrouped + grouped = grouped + saved = grouped+ungrouped + + For grouped uses following order: user -> saved (because 'saved' include 'grouped' after first initialisation). 
This should be done with saving order like for 'groupedColumns' + For ungrouped: uses only 'user' input ordering +*/ \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html new file mode 100644 index 0000000000000..66d111e439096 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -0,0 +1,81 @@ + + + diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js new file mode 100644 index 0000000000000..785abe45bc56e --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// this function works exactly the same as UIUtils.formatDuration +function formatDuration(milliseconds) { + if (milliseconds < 100) { + return milliseconds + " ms"; + } + var seconds = milliseconds * 1.0 / 1000; + if (seconds < 1) { + return seconds.toFixed(1) + " s"; + } + if (seconds < 60) { + return seconds.toFixed(0) + " s"; + } + var minutes = seconds / 60; + if (minutes < 10) { + return minutes.toFixed(1) + " min"; + } else if (minutes < 60) { + return minutes.toFixed(0) + " min"; + } + var hours = minutes / 60; + return hours.toFixed(1) + " h"; +} + +function formatDate(date) { + return date.split(".")[0].replace("T", " "); +} + +function getParameterByName(name, searchString) { + var regex = new RegExp("[\\?&]" + name + "=([^&#]*)"), + results = regex.exec(searchString); + return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " ")); +} + +jQuery.extend( jQuery.fn.dataTableExt.oSort, { + "title-numeric-pre": function ( a ) { + var x = a.match(/title="*(-?[0-9\.]+)/)[1]; + return parseFloat( x ); + }, + + "title-numeric-asc": function ( a, b ) { + return ((a < b) ? -1 : ((a > b) ? 1 : 0)); + }, + + "title-numeric-desc": function ( a, b ) { + return ((a < b) ? 1 : ((a > b) ? -1 : 0)); + } +} ); + +$(document).ajaxStop($.unblockUI); +$(document).ajaxStart(function(){ + $.blockUI({ message: '

      Loading history summary...

      '}); +}); + +$(document).ready(function() { + $.extend( $.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20,40,60,100,-1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); + + historySummary = $("#history-summary"); + searchString = historySummary["context"]["location"]["search"]; + requestedIncomplete = getParameterByName("showIncomplete", searchString); + requestedIncomplete = (requestedIncomplete == "true" ? true : false); + + $.getJSON("/api/v1/applications", function(response,status,jqXHR) { + var array = []; + var hasMultipleAttempts = false; + for (i in response) { + var app = response[i]; + if (app["attempts"][0]["completed"] == requestedIncomplete) { + continue; // if we want to show for Incomplete, we skip the completed apps; otherwise skip incomplete ones. + } + var id = app["id"]; + var name = app["name"]; + if (app["attempts"].length > 1) { + hasMultipleAttempts = true; + } + var num = app["attempts"].length; + for (j in app["attempts"]) { + var attempt = app["attempts"][j]; + attempt["startTime"] = formatDate(attempt["startTime"]); + attempt["endTime"] = formatDate(attempt["endTime"]); + attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; + array.push(app_clone); + } + } + + var data = {"applications": array} + $.get("/static/historypage-template.html", function(template) { + historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); + var selector = "#history-summary-table"; + var conf = { + "columns": [ + {name: 'first'}, + {name: 'second'}, + {name: 'third'}, + {name: 'fourth'}, + {name: 'fifth'}, + {name: 'sixth', type: "title-numeric"}, + {name: 'seventh'}, + {name: 'eighth'}, + ], + }; + + var rowGroupConf = { + "rowsGroup": [ + 'first:name', + 'second:name' + ], + }; + + if (hasMultipleAttempts) { + jQuery.extend(conf, rowGroupConf); + var rowGroupCells = document.getElementsByClassName("rowGroupColumn"); + for (i = 0; i < rowGroupCells.length; i++) { + rowGroupCells[i].style='background-color: #ffffff'; + } + } + + if (!hasMultipleAttempts) { + var attemptIDCells = document.getElementsByClassName("attemptIDSpan"); + for (i = 0; i < attemptIDCells.length; i++) { + attemptIDCells[i].style.display='none'; + } + } + + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } + + if ($(selector.concat(" tr")).length < 20) { + $.extend(conf, {paging: false}); + } + + $(selector).DataTable(conf); + $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + }); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js new file mode 100644 index 0000000000000..1e84b3ec21c4d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js @@ -0,0 +1,6 @@ +/* +* jQuery BlockUI; v20131009 +* http://jquery.malsup.com/block/ +* Copyright (c) 2013 M. 
Alsup; Dual licensed: MIT/GPL +*/ +(function(){"use strict";function e(e){function o(o,i){var s,h,k=o==window,v=i&&void 0!==i.message?i.message:void 0;if(i=e.extend({},e.blockUI.defaults,i||{}),!i.ignoreIfBlocked||!e(o).data("blockUI.isBlocked")){if(i.overlayCSS=e.extend({},e.blockUI.defaults.overlayCSS,i.overlayCSS||{}),s=e.extend({},e.blockUI.defaults.css,i.css||{}),i.onOverlayClick&&(i.overlayCSS.cursor="pointer"),h=e.extend({},e.blockUI.defaults.themedCSS,i.themedCSS||{}),v=void 0===v?i.message:v,k&&b&&t(window,{fadeOut:0}),v&&"string"!=typeof v&&(v.parentNode||v.jquery)){var y=v.jquery?v[0]:v,m={};e(o).data("blockUI.history",m),m.el=y,m.parent=y.parentNode,m.display=y.style.display,m.position=y.style.position,m.parent&&m.parent.removeChild(y)}e(o).data("blockUI.onUnblock",i.onUnblock);var g,I,w,U,x=i.baseZ;g=r||i.forceIframe?e(''):e(''),I=i.theme?e(''):e(''),i.theme&&k?(U='"):i.theme?(U='"):U=k?'':'',w=e(U),v&&(i.theme?(w.css(h),w.addClass("ui-widget-content")):w.css(s)),i.theme||I.css(i.overlayCSS),I.css("position",k?"fixed":"absolute"),(r||i.forceIframe)&&g.css("opacity",0);var C=[g,I,w],S=k?e("body"):e(o);e.each(C,function(){this.appendTo(S)}),i.theme&&i.draggable&&e.fn.draggable&&w.draggable({handle:".ui-dialog-titlebar",cancel:"li"});var O=f&&(!e.support.boxModel||e("object,embed",k?null:o).length>0);if(u||O){if(k&&i.allowBodyStretch&&e.support.boxModel&&e("html,body").css("height","100%"),(u||!e.support.boxModel)&&!k)var E=d(o,"borderTopWidth"),T=d(o,"borderLeftWidth"),M=E?"(0 - "+E+")":0,B=T?"(0 - "+T+")":0;e.each(C,function(e,o){var t=o[0].style;if(t.position="absolute",2>e)k?t.setExpression("height","Math.max(document.body.scrollHeight, document.body.offsetHeight) - (jQuery.support.boxModel?0:"+i.quirksmodeOffsetHack+') + "px"'):t.setExpression("height",'this.parentNode.offsetHeight + "px"'),k?t.setExpression("width",'jQuery.support.boxModel && document.documentElement.clientWidth || document.body.clientWidth + "px"'):t.setExpression("width",'this.parentNode.offsetWidth + "px"'),B&&t.setExpression("left",B),M&&t.setExpression("top",M);else if(i.centerY)k&&t.setExpression("top",'(document.documentElement.clientHeight || document.body.clientHeight) / 2 - (this.offsetHeight / 2) + (blah = document.documentElement.scrollTop ? document.documentElement.scrollTop : document.body.scrollTop) + "px"'),t.marginTop=0;else if(!i.centerY&&k){var n=i.css&&i.css.top?parseInt(i.css.top,10):0,s="((document.documentElement.scrollTop ? 
document.documentElement.scrollTop : document.body.scrollTop) + "+n+') + "px"';t.setExpression("top",s)}})}if(v&&(i.theme?w.find(".ui-widget-content").append(v):w.append(v),(v.jquery||v.nodeType)&&e(v).show()),(r||i.forceIframe)&&i.showOverlay&&g.show(),i.fadeIn){var j=i.onBlock?i.onBlock:c,H=i.showOverlay&&!v?j:c,z=v?j:c;i.showOverlay&&I._fadeIn(i.fadeIn,H),v&&w._fadeIn(i.fadeIn,z)}else i.showOverlay&&I.show(),v&&w.show(),i.onBlock&&i.onBlock();if(n(1,o,i),k?(b=w[0],p=e(i.focusableElements,b),i.focusInput&&setTimeout(l,20)):a(w[0],i.centerX,i.centerY),i.timeout){var W=setTimeout(function(){k?e.unblockUI(i):e(o).unblock(i)},i.timeout);e(o).data("blockUI.timeout",W)}}}function t(o,t){var s,l=o==window,a=e(o),d=a.data("blockUI.history"),c=a.data("blockUI.timeout");c&&(clearTimeout(c),a.removeData("blockUI.timeout")),t=e.extend({},e.blockUI.defaults,t||{}),n(0,o,t),null===t.onUnblock&&(t.onUnblock=a.data("blockUI.onUnblock"),a.removeData("blockUI.onUnblock"));var r;r=l?e("body").children().filter(".blockUI").add("body > .blockUI"):a.find(">.blockUI"),t.cursorReset&&(r.length>1&&(r[1].style.cursor=t.cursorReset),r.length>2&&(r[2].style.cursor=t.cursorReset)),l&&(b=p=null),t.fadeOut?(s=r.length,r.stop().fadeOut(t.fadeOut,function(){0===--s&&i(r,d,t,o)})):i(r,d,t,o)}function i(o,t,i,n){var s=e(n);if(!s.data("blockUI.isBlocked")){o.each(function(){this.parentNode&&this.parentNode.removeChild(this)}),t&&t.el&&(t.el.style.display=t.display,t.el.style.position=t.position,t.parent&&t.parent.appendChild(t.el),s.removeData("blockUI.history")),s.data("blockUI.static")&&s.css("position","static"),"function"==typeof i.onUnblock&&i.onUnblock(n,i);var l=e(document.body),a=l.width(),d=l[0].style.width;l.width(a-1).width(a),l[0].style.width=d}}function n(o,t,i){var n=t==window,l=e(t);if((o||(!n||b)&&(n||l.data("blockUI.isBlocked")))&&(l.data("blockUI.isBlocked",o),n&&i.bindEvents&&(!o||i.showOverlay))){var a="mousedown mouseup keydown keypress keyup touchstart touchend touchmove";o?e(document).bind(a,i,s):e(document).unbind(a,s)}}function s(o){if("keydown"===o.type&&o.keyCode&&9==o.keyCode&&b&&o.data.constrainTabKey){var t=p,i=!o.shiftKey&&o.target===t[t.length-1],n=o.shiftKey&&o.target===t[0];if(i||n)return setTimeout(function(){l(n)},10),!1}var s=o.data,a=e(o.target);return a.hasClass("blockOverlay")&&s.onOverlayClick&&s.onOverlayClick(o),a.parents("div."+s.blockMsgClass).length>0?!0:0===a.parents().children().filter("div.blockUI").length}function l(e){if(p){var o=p[e===!0?p.length-1:0];o&&o.focus()}}function a(e,o,t){var i=e.parentNode,n=e.style,s=(i.offsetWidth-e.offsetWidth)/2-d(i,"borderLeftWidth"),l=(i.offsetHeight-e.offsetHeight)/2-d(i,"borderTopWidth");o&&(n.left=s>0?s+"px":"0"),t&&(n.top=l>0?l+"px":"0")}function d(o,t){return parseInt(e.css(o,t),10)||0}e.fn._fadeIn=e.fn.fadeIn;var c=e.noop||function(){},r=/MSIE/.test(navigator.userAgent),u=/MSIE 6.0/.test(navigator.userAgent)&&!/MSIE 8.0/.test(navigator.userAgent);document.documentMode||0;var f=e.isFunction(document.createElement("div").style.setExpression);e.blockUI=function(e){o(window,e)},e.unblockUI=function(e){t(window,e)},e.growlUI=function(o,t,i,n){var s=e('
      ');o&&s.append("

      "+o+"

      "),t&&s.append("

      "+t+"

      "),void 0===i&&(i=3e3);var l=function(o){o=o||{},e.blockUI({message:s,fadeIn:o.fadeIn!==void 0?o.fadeIn:700,fadeOut:o.fadeOut!==void 0?o.fadeOut:1e3,timeout:o.timeout!==void 0?o.timeout:i,centerY:!1,showOverlay:!1,onUnblock:n,css:e.blockUI.defaults.growlCSS})};l(),s.css("opacity"),s.mouseover(function(){l({fadeIn:0,timeout:3e4});var o=e(".blockMsg");o.stop(),o.fadeTo(300,1)}).mouseout(function(){e(".blockMsg").fadeOut(1e3)})},e.fn.block=function(t){if(this[0]===window)return e.blockUI(t),this;var i=e.extend({},e.blockUI.defaults,t||{});return this.each(function(){var o=e(this);i.ignoreIfBlocked&&o.data("blockUI.isBlocked")||o.unblock({fadeOut:0})}),this.each(function(){"static"==e.css(this,"position")&&(this.style.position="relative",e(this).data("blockUI.static",!0)),this.style.zoom=1,o(this,t)})},e.fn.unblock=function(o){return this[0]===window?(e.unblockUI(o),this):this.each(function(){t(this,o)})},e.blockUI.version=2.66,e.blockUI.defaults={message:"

      Please wait...

      ",title:null,draggable:!0,theme:!1,css:{padding:0,margin:0,width:"30%",top:"40%",left:"35%",textAlign:"center",color:"#000",border:"3px solid #aaa",backgroundColor:"#fff",cursor:"wait"},themedCSS:{width:"30%",top:"40%",left:"35%"},overlayCSS:{backgroundColor:"#000",opacity:.6,cursor:"wait"},cursorReset:"default",growlCSS:{width:"350px",top:"10px",left:"",right:"10px",border:"none",padding:"5px",opacity:.6,cursor:"default",color:"#fff",backgroundColor:"#000","-webkit-border-radius":"10px","-moz-border-radius":"10px","border-radius":"10px"},iframeSrc:/^https/i.test(window.location.href||"")?"javascript:false":"about:blank",forceIframe:!1,baseZ:1e3,centerX:!0,centerY:!0,allowBodyStretch:!0,bindEvents:!0,constrainTabKey:!0,fadeIn:200,fadeOut:400,timeout:0,showOverlay:!0,focusInput:!0,focusableElements:":input:enabled:visible",onBlock:null,onUnblock:null,onOverlayClick:null,quirksmodeOffsetHack:4,blockMsgClass:"blockMsg",ignoreIfBlocked:!1};var b=null,p=[]}"function"==typeof define&&define.amd&&define.amd.jQuery?define(["jquery"],e):e(jQuery)})(); \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js new file mode 100644 index 0000000000000..bd2dacb4eeebd --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js @@ -0,0 +1,18 @@ +/** + * Copyright (c) 2005 - 2010, James Auldridge + * All rights reserved. + * + * Licensed under the BSD, MIT, and GPL (your choice!) Licenses: + * http://code.google.com/p/cookies/wiki/License + * + */ +var jaaulde=window.jaaulde||{};jaaulde.utils=jaaulde.utils||{};jaaulde.utils.cookies=(function(){var resolveOptions,assembleOptionsString,parseCookies,constructor,defaultOptions={expiresAt:null,path:'/',domain:null,secure:false};resolveOptions=function(options){var returnValue,expireDate;if(typeof options!=='object'||options===null){returnValue=defaultOptions;}else +{returnValue={expiresAt:defaultOptions.expiresAt,path:defaultOptions.path,domain:defaultOptions.domain,secure:defaultOptions.secure};if(typeof options.expiresAt==='object'&&options.expiresAt instanceof Date){returnValue.expiresAt=options.expiresAt;}else if(typeof options.hoursToLive==='number'&&options.hoursToLive!==0){expireDate=new Date();expireDate.setTime(expireDate.getTime()+(options.hoursToLive*60*60*1000));returnValue.expiresAt=expireDate;}if(typeof options.path==='string'&&options.path!==''){returnValue.path=options.path;}if(typeof options.domain==='string'&&options.domain!==''){returnValue.domain=options.domain;}if(options.secure===true){returnValue.secure=options.secure;}}return returnValue;};assembleOptionsString=function(options){options=resolveOptions(options);return((typeof options.expiresAt==='object'&&options.expiresAt instanceof Date?'; expires='+options.expiresAt.toGMTString():'')+'; path='+options.path+(typeof options.domain==='string'?'; domain='+options.domain:'')+(options.secure===true?'; secure':''));};parseCookies=function(){var cookies={},i,pair,name,value,separated=document.cookie.split(';'),unparsedValue;for(i=0;i.sorting_1,table.dataTable.order-column tbody tr>.sorting_2,table.dataTable.order-column tbody tr>.sorting_3,table.dataTable.display tbody tr>.sorting_1,table.dataTable.display tbody tr>.sorting_2,table.dataTable.display tbody tr>.sorting_3{background-color:#f9f9f9}table.dataTable.order-column tbody tr.selected>.sorting_1,table.dataTable.order-column tbody 
tr.selected>.sorting_2,table.dataTable.order-column tbody tr.selected>.sorting_3,table.dataTable.display tbody tr.selected>.sorting_1,table.dataTable.display tbody tr.selected>.sorting_2,table.dataTable.display tbody tr.selected>.sorting_3{background-color:#acbad4}table.dataTable.display tbody tr.odd>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd>.sorting_1{background-color:#f1f1f1}table.dataTable.display tbody tr.odd>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd>.sorting_2{background-color:#f3f3f3}table.dataTable.display tbody tr.odd>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd>.sorting_3{background-color:#f5f5f5}table.dataTable.display tbody tr.odd.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_1{background-color:#a6b3cd}table.dataTable.display tbody tr.odd.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_2{background-color:#a7b5ce}table.dataTable.display tbody tr.odd.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_3{background-color:#a9b6d0}table.dataTable.display tbody tr.even>.sorting_1,table.dataTable.order-column.stripe tbody tr.even>.sorting_1{background-color:#f9f9f9}table.dataTable.display tbody tr.even>.sorting_2,table.dataTable.order-column.stripe tbody tr.even>.sorting_2{background-color:#fbfbfb}table.dataTable.display tbody tr.even>.sorting_3,table.dataTable.order-column.stripe tbody tr.even>.sorting_3{background-color:#fdfdfd}table.dataTable.display tbody tr.even.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_1{background-color:#acbad4}table.dataTable.display tbody tr.even.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_2{background-color:#adbbd6}table.dataTable.display tbody tr.even.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_3{background-color:#afbdd8}table.dataTable.display tbody tr:hover>.sorting_1,table.dataTable.display tbody tr.odd:hover>.sorting_1,table.dataTable.display tbody tr.even:hover>.sorting_1,table.dataTable.order-column.hover tbody tr:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_1{background-color:#eaeaea}table.dataTable.display tbody tr:hover>.sorting_2,table.dataTable.display tbody tr.odd:hover>.sorting_2,table.dataTable.display tbody tr.even:hover>.sorting_2,table.dataTable.order-column.hover tbody tr:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_2{background-color:#ebebeb}table.dataTable.display tbody tr:hover>.sorting_3,table.dataTable.display tbody tr.odd:hover>.sorting_3,table.dataTable.display tbody tr.even:hover>.sorting_3,table.dataTable.order-column.hover tbody tr:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_3{background-color:#eee}table.dataTable.display tbody tr:hover.selected>.sorting_1,table.dataTable.display tbody tr.odd:hover.selected>.sorting_1,table.dataTable.display tbody tr.even:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody 
tr.even:hover.selected>.sorting_1{background-color:#a1aec7}table.dataTable.display tbody tr:hover.selected>.sorting_2,table.dataTable.display tbody tr.odd:hover.selected>.sorting_2,table.dataTable.display tbody tr.even:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_2{background-color:#a2afc8}table.dataTable.display tbody tr:hover.selected>.sorting_3,table.dataTable.display tbody tr.odd:hover.selected>.sorting_3,table.dataTable.display tbody tr.even:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_3{background-color:#a4b2cb}table.dataTable.no-footer{border-bottom:1px solid #111}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.compact thead th,table.dataTable.compact thead td{padding:5px 9px}table.dataTable.compact tfoot th,table.dataTable.compact tfoot td{padding:5px 9px 3px 9px}table.dataTable.compact tbody th,table.dataTable.compact tbody td{padding:4px 5px}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable,table.dataTable th,table.dataTable td{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box}.dataTables_wrapper{position:relative;clear:both;*zoom:1;zoom:1}.dataTables_wrapper .dataTables_length{float:left}.dataTables_wrapper .dataTables_filter{float:right;text-align:right}.dataTables_wrapper .dataTables_filter input{margin-left:0.5em}.dataTables_wrapper 
.dataTables_info{clear:both;float:left;padding-top:0.755em}.dataTables_wrapper .dataTables_paginate{float:right;text-align:right;padding-top:0.25em}.dataTables_wrapper .dataTables_paginate .paginate_button{box-sizing:border-box;display:inline-block;min-width:1.5em;padding:0.5em 1em;margin-left:2px;text-align:center;text-decoration:none !important;cursor:pointer;*cursor:hand;color:#333 !important;border:1px solid transparent}.dataTables_wrapper .dataTables_paginate .paginate_button.current,.dataTables_wrapper .dataTables_paginate .paginate_button.current:hover{color:#333 !important;border:1px solid #cacaca;background-color:#fff;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #fff), color-stop(100%, #dcdcdc));background:-webkit-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-moz-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-ms-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-o-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:linear-gradient(to bottom, #fff 0%, #dcdcdc 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button.disabled,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:hover,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:active{cursor:default;color:#666 !important;border:1px solid transparent;background:transparent;box-shadow:none}.dataTables_wrapper .dataTables_paginate .paginate_button:hover{color:white !important;border:1px solid #111;background-color:#585858;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #585858), color-stop(100%, #111));background:-webkit-linear-gradient(top, #585858 0%, #111 100%);background:-moz-linear-gradient(top, #585858 0%, #111 100%);background:-ms-linear-gradient(top, #585858 0%, #111 100%);background:-o-linear-gradient(top, #585858 0%, #111 100%);background:linear-gradient(to bottom, #585858 0%, #111 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button:active{outline:none;background-color:#2b2b2b;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #2b2b2b), color-stop(100%, #0c0c0c));background:-webkit-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-moz-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-ms-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-o-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:linear-gradient(to bottom, #2b2b2b 0%, #0c0c0c 100%);box-shadow:inset 0 0 3px #111}.dataTables_wrapper .dataTables_processing{position:absolute;top:50%;left:50%;width:100%;height:40px;margin-left:-50%;margin-top:-25px;padding-top:20px;text-align:center;font-size:1.2em;background-color:white;background:-webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));background:-webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, 
rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%)}.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter,.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_processing,.dataTables_wrapper .dataTables_paginate{color:#333}.dataTables_wrapper .dataTables_scroll{clear:both}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody{*margin-top:-1px;-webkit-overflow-scrolling:touch}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody th>div.dataTables_sizing,.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody td>div.dataTables_sizing{height:0;overflow:hidden;margin:0 !important;padding:0 !important}.dataTables_wrapper.no-footer .dataTables_scrollBody{border-bottom:1px solid #111}.dataTables_wrapper.no-footer div.dataTables_scrollHead table,.dataTables_wrapper.no-footer div.dataTables_scrollBody table{border-bottom:none}.dataTables_wrapper:after{visibility:hidden;display:block;content:"";clear:both;height:0}@media screen and (max-width: 767px){.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_paginate{float:none;text-align:center}.dataTables_wrapper .dataTables_paginate{margin-top:0.5em}}@media screen and (max-width: 640px){.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter{float:none;text-align:center}.dataTables_wrapper .dataTables_filter{margin-top:0.5em}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js new file mode 100644 index 0000000000000..8885017c35d0d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js @@ -0,0 +1,157 @@ +/*! DataTables 1.10.4 + * ©2008-2014 SpryMedia Ltd - datatables.net/license + */ +(function(Da,P,l){var O=function(g){function V(a){var b,c,e={};g.each(a,function(d){if((b=d.match(/^([^A-Z]+?)([A-Z])/))&&-1!=="a aa ai ao as b fn i m o s ".indexOf(b[1]+" "))c=d.replace(b[0],b[2].toLowerCase()),e[c]=d,"o"===b[1]&&V(a[d])});a._hungarianMap=e}function G(a,b,c){a._hungarianMap||V(a);var e;g.each(b,function(d){e=a._hungarianMap[d];if(e!==l&&(c||b[e]===l))"o"===e.charAt(0)?(b[e]||(b[e]={}),g.extend(!0,b[e],b[d]),G(a[e],b[e],c)):b[e]=b[d]})}function O(a){var b=p.defaults.oLanguage,c=a.sZeroRecords; +!a.sEmptyTable&&(c&&"No data available in table"===b.sEmptyTable)&&D(a,a,"sZeroRecords","sEmptyTable");!a.sLoadingRecords&&(c&&"Loading..."===b.sLoadingRecords)&&D(a,a,"sZeroRecords","sLoadingRecords");a.sInfoThousands&&(a.sThousands=a.sInfoThousands);(a=a.sDecimal)&&cb(a)}function db(a){z(a,"ordering","bSort");z(a,"orderMulti","bSortMulti");z(a,"orderClasses","bSortClasses");z(a,"orderCellsTop","bSortCellsTop");z(a,"order","aaSorting");z(a,"orderFixed","aaSortingFixed");z(a,"paging","bPaginate"); +z(a,"pagingType","sPaginationType");z(a,"pageLength","iDisplayLength");z(a,"searching","bFilter");if(a=a.aoSearchCols)for(var b=0,c=a.length;b").css({position:"absolute",top:0,left:0,height:1,width:1,overflow:"hidden"}).append(g("
      ").css({position:"absolute",top:1,left:1,width:100, +overflow:"scroll"}).append(g('
      ').css({width:"100%",height:10}))).appendTo("body"),c=b.find(".test");a.bScrollOversize=100===c[0].offsetWidth;a.bScrollbarLeft=1!==c.offset().left;b.remove()}function gb(a,b,c,e,d,f){var h,i=!1;c!==l&&(h=c,i=!0);for(;e!==d;)a.hasOwnProperty(e)&&(h=i?b(h,a[e],e,a):a[e],i=!0,e+=f);return h}function Ea(a,b){var c=p.defaults.column,e=a.aoColumns.length,c=g.extend({},p.models.oColumn,c,{nTh:b?b:P.createElement("th"),sTitle:c.sTitle?c.sTitle:b?b.innerHTML: +"",aDataSort:c.aDataSort?c.aDataSort:[e],mData:c.mData?c.mData:e,idx:e});a.aoColumns.push(c);c=a.aoPreSearchCols;c[e]=g.extend({},p.models.oSearch,c[e]);ja(a,e,null)}function ja(a,b,c){var b=a.aoColumns[b],e=a.oClasses,d=g(b.nTh);if(!b.sWidthOrig){b.sWidthOrig=d.attr("width")||null;var f=(d.attr("style")||"").match(/width:\s*(\d+[pxem%]+)/);f&&(b.sWidthOrig=f[1])}c!==l&&null!==c&&(eb(c),G(p.defaults.column,c),c.mDataProp!==l&&!c.mData&&(c.mData=c.mDataProp),c.sType&&(b._sManualType=c.sType),c.className&& +!c.sClass&&(c.sClass=c.className),g.extend(b,c),D(b,c,"sWidth","sWidthOrig"),"number"===typeof c.iDataSort&&(b.aDataSort=[c.iDataSort]),D(b,c,"aDataSort"));var h=b.mData,i=W(h),j=b.mRender?W(b.mRender):null,c=function(a){return"string"===typeof a&&-1!==a.indexOf("@")};b._bAttrSrc=g.isPlainObject(h)&&(c(h.sort)||c(h.type)||c(h.filter));b.fnGetData=function(a,b,c){var e=i(a,b,l,c);return j&&b?j(e,b,a,c):e};b.fnSetData=function(a,b,c){return Q(h)(a,b,c)};"number"!==typeof h&&(a._rowReadObject=!0);a.oFeatures.bSort|| +(b.bSortable=!1,d.addClass(e.sSortableNone));a=-1!==g.inArray("asc",b.asSorting);c=-1!==g.inArray("desc",b.asSorting);!b.bSortable||!a&&!c?(b.sSortingClass=e.sSortableNone,b.sSortingClassJUI=""):a&&!c?(b.sSortingClass=e.sSortableAsc,b.sSortingClassJUI=e.sSortJUIAscAllowed):!a&&c?(b.sSortingClass=e.sSortableDesc,b.sSortingClassJUI=e.sSortJUIDescAllowed):(b.sSortingClass=e.sSortable,b.sSortingClassJUI=e.sSortJUI)}function X(a){if(!1!==a.oFeatures.bAutoWidth){var b=a.aoColumns;Fa(a);for(var c=0,e=b.length;c< +e;c++)b[c].nTh.style.width=b[c].sWidth}b=a.oScroll;(""!==b.sY||""!==b.sX)&&Y(a);u(a,null,"column-sizing",[a])}function ka(a,b){var c=Z(a,"bVisible");return"number"===typeof c[b]?c[b]:null}function $(a,b){var c=Z(a,"bVisible"),c=g.inArray(b,c);return-1!==c?c:null}function aa(a){return Z(a,"bVisible").length}function Z(a,b){var c=[];g.map(a.aoColumns,function(a,d){a[b]&&c.push(d)});return c}function Ga(a){var b=a.aoColumns,c=a.aoData,e=p.ext.type.detect,d,f,h,i,j,g,m,o,k;d=0;for(f=b.length;do[f])e(m.length+o[f],n);else if("string"===typeof o[f]){i=0;for(j=m.length;ib&&a[d]--; -1!=e&&c===l&& +a.splice(e,1)}function ca(a,b,c,e){var d=a.aoData[b],f,h=function(c,f){for(;c.childNodes.length;)c.removeChild(c.firstChild);c.innerHTML=v(a,b,f,"display")};if("dom"===c||(!c||"auto"===c)&&"dom"===d.src)d._aData=ma(a,d,e,e===l?l:d._aData).data;else{var i=d.anCells;if(i)if(e!==l)h(i[e],e);else{c=0;for(f=i.length;c").appendTo(h));b=0;for(c= +m.length;btr").attr("role","row");g(h).find(">tr>th, >tr>td").addClass(n.sHeaderTH);g(i).find(">tr>th, >tr>td").addClass(n.sFooterTH);if(null!==i){a=a.aoFooter[0];b=0;for(c=a.length;b=a.fnRecordsDisplay()?0:h,a.iInitDisplayStart=-1);var h=a._iDisplayStart,n=a.fnDisplayEnd();if(a.bDeferLoading)a.bDeferLoading= +!1,a.iDraw++,B(a,!1);else if(i){if(!a.bDestroying&&!jb(a))return}else a.iDraw++;if(0!==j.length){f=i?a.aoData.length:n;for(i=i?0:h;i",{"class":d? 
+e[0]:""}).append(g("",{valign:"top",colSpan:aa(a),"class":a.oClasses.sRowEmpty}).html(c))[0];u(a,"aoHeaderCallback","header",[g(a.nTHead).children("tr")[0],Ka(a),h,n,j]);u(a,"aoFooterCallback","footer",[g(a.nTFoot).children("tr")[0],Ka(a),h,n,j]);e=g(a.nTBody);e.children().detach();e.append(g(b));u(a,"aoDrawCallback","draw",[a]);a.bSorted=!1;a.bFiltered=!1;a.bDrawing=!1}}function M(a,b){var c=a.oFeatures,e=c.bFilter;c.bSort&&kb(a);e?fa(a,a.oPreviousSearch):a.aiDisplay=a.aiDisplayMaster.slice(); +!0!==b&&(a._iDisplayStart=0);a._drawHold=b;L(a);a._drawHold=!1}function lb(a){var b=a.oClasses,c=g(a.nTable),c=g("
      ").insertBefore(c),e=a.oFeatures,d=g("
      ",{id:a.sTableId+"_wrapper","class":b.sWrapper+(a.nTFoot?"":" "+b.sNoFooter)});a.nHolding=c[0];a.nTableWrapper=d[0];a.nTableReinsertBefore=a.nTable.nextSibling;for(var f=a.sDom.split(""),h,i,j,n,m,o,k=0;k")[0];n=f[k+1];if("'"==n||'"'==n){m="";for(o=2;f[k+o]!=n;)m+=f[k+o],o++;"H"==m?m=b.sJUIHeader: +"F"==m&&(m=b.sJUIFooter);-1!=m.indexOf(".")?(n=m.split("."),j.id=n[0].substr(1,n[0].length-1),j.className=n[1]):"#"==m.charAt(0)?j.id=m.substr(1,m.length-1):j.className=m;k+=o}d.append(j);d=g(j)}else if(">"==i)d=d.parent();else if("l"==i&&e.bPaginate&&e.bLengthChange)h=mb(a);else if("f"==i&&e.bFilter)h=nb(a);else if("r"==i&&e.bProcessing)h=ob(a);else if("t"==i)h=pb(a);else if("i"==i&&e.bInfo)h=qb(a);else if("p"==i&&e.bPaginate)h=rb(a);else if(0!==p.ext.feature.length){j=p.ext.feature;o=0;for(n=j.length;o< +n;o++)if(i==j[o].cFeature){h=j[o].fnInit(a);break}}h&&(j=a.aanFeatures,j[i]||(j[i]=[]),j[i].push(h),d.append(h))}c.replaceWith(d)}function da(a,b){var c=g(b).children("tr"),e,d,f,h,i,j,n,m,o,k;a.splice(0,a.length);f=0;for(j=c.length;f',i=e.sSearch,i=i.match(/_INPUT_/)?i.replace("_INPUT_",h):i+h,b=g("
      ",{id:!f.f?c+"_filter":null,"class":b.sFilter}).append(g("
      ").addClass(b.sLength);a.aanFeatures.l||(j[0].id=c+"_length");j.children().append(a.oLanguage.sLengthMenu.replace("_MENU_",d[0].outerHTML));g("select",j).val(a._iDisplayLength).bind("change.DT", +function(){Qa(a,g(this).val());L(a)});g(a.nTable).bind("length.dt.DT",function(b,c,f){a===c&&g("select",j).val(f)});return j[0]}function rb(a){var b=a.sPaginationType,c=p.ext.pager[b],e="function"===typeof c,d=function(a){L(a)},b=g("
      ").addClass(a.oClasses.sPaging+b)[0],f=a.aanFeatures;e||c.fnInit(a,b,d);f.p||(b.id=a.sTableId+"_paginate",a.aoDrawCallback.push({fn:function(a){if(e){var b=a._iDisplayStart,g=a._iDisplayLength,n=a.fnRecordsDisplay(),m=-1===g,b=m?0:Math.ceil(b/g),g=m?1:Math.ceil(n/ +g),n=c(b,g),o,m=0;for(o=f.p.length;mf&&(e=0)):"first"==b?e=0:"previous"==b?(e=0<=d?e-d:0,0>e&&(e=0)):"next"==b?e+d",{id:!a.aanFeatures.r?a.sTableId+"_processing":null,"class":a.oClasses.sProcessing}).html(a.oLanguage.sProcessing).insertBefore(a.nTable)[0]}function B(a,b){a.oFeatures.bProcessing&&g(a.aanFeatures.r).css("display",b?"block":"none");u(a,null,"processing",[a,b])}function pb(a){var b=g(a.nTable);b.attr("role","grid");var c=a.oScroll;if(""===c.sX&&""===c.sY)return a.nTable;var e=c.sX,d=c.sY,f=a.oClasses,h=b.children("caption"),i=h.length?h[0]._captionSide:null, +j=g(b[0].cloneNode(!1)),n=g(b[0].cloneNode(!1)),m=b.children("tfoot");c.sX&&"100%"===b.attr("width")&&b.removeAttr("width");m.length||(m=null);c=g("
      ",{"class":f.sScrollWrapper}).append(g("
      ",{"class":f.sScrollHead}).css({overflow:"hidden",position:"relative",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollHeadInner}).css({"box-sizing":"content-box",width:c.sXInner||"100%"}).append(j.removeAttr("id").css("margin-left",0).append("top"===i?h:null).append(b.children("thead"))))).append(g("
      ", +{"class":f.sScrollBody}).css({overflow:"auto",height:!d?null:s(d),width:!e?null:s(e)}).append(b));m&&c.append(g("
      ",{"class":f.sScrollFoot}).css({overflow:"hidden",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollFootInner}).append(n.removeAttr("id").css("margin-left",0).append("bottom"===i?h:null).append(b.children("tfoot")))));var b=c.children(),o=b[0],f=b[1],k=m?b[2]:null;e&&g(f).scroll(function(){var a=this.scrollLeft;o.scrollLeft=a;m&&(k.scrollLeft=a)});a.nScrollHead= +o;a.nScrollBody=f;a.nScrollFoot=k;a.aoDrawCallback.push({fn:Y,sName:"scrolling"});return c[0]}function Y(a){var b=a.oScroll,c=b.sX,e=b.sXInner,d=b.sY,f=b.iBarWidth,h=g(a.nScrollHead),i=h[0].style,j=h.children("div"),n=j[0].style,m=j.children("table"),j=a.nScrollBody,o=g(j),k=j.style,l=g(a.nScrollFoot).children("div"),p=l.children("table"),r=g(a.nTHead),q=g(a.nTable),t=q[0],N=t.style,J=a.nTFoot?g(a.nTFoot):null,u=a.oBrowser,w=u.bScrollOversize,y,v,x,K,z,A=[],B=[],C=[],D,E=function(a){a=a.style;a.paddingTop= +"0";a.paddingBottom="0";a.borderTopWidth="0";a.borderBottomWidth="0";a.height=0};q.children("thead, tfoot").remove();z=r.clone().prependTo(q);y=r.find("tr");x=z.find("tr");z.find("th, td").removeAttr("tabindex");J&&(K=J.clone().prependTo(q),v=J.find("tr"),K=K.find("tr"));c||(k.width="100%",h[0].style.width="100%");g.each(pa(a,z),function(b,c){D=ka(a,b);c.style.width=a.aoColumns[D].sWidth});J&&F(function(a){a.style.width=""},K);b.bCollapse&&""!==d&&(k.height=o[0].offsetHeight+r[0].offsetHeight+"px"); +h=q.outerWidth();if(""===c){if(N.width="100%",w&&(q.find("tbody").height()>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(q.outerWidth()-f)}else""!==e?N.width=s(e):h==o.width()&&o.height()h-f&&(N.width=s(h))):N.width=s(h);h=q.outerWidth();F(E,x);F(function(a){C.push(a.innerHTML);A.push(s(g(a).css("width")))},x);F(function(a,b){a.style.width=A[b]},y);g(x).height(0);J&&(F(E,K),F(function(a){B.push(s(g(a).css("width")))},K),F(function(a,b){a.style.width= +B[b]},v),g(K).height(0));F(function(a,b){a.innerHTML='
      '+C[b]+"
      ";a.style.width=A[b]},x);J&&F(function(a,b){a.innerHTML="";a.style.width=B[b]},K);if(q.outerWidth()j.offsetHeight||"scroll"==o.css("overflow-y")?h+f:h;if(w&&(j.scrollHeight>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(v-f);(""===c||""!==e)&&R(a,1,"Possible column misalignment",6)}else v="100%";k.width=s(v);i.width=s(v);J&&(a.nScrollFoot.style.width= +s(v));!d&&w&&(k.height=s(t.offsetHeight+f));d&&b.bCollapse&&(k.height=s(d),b=c&&t.offsetWidth>j.offsetWidth?f:0,t.offsetHeightj.clientHeight||"scroll"==o.css("overflow-y");u="padding"+(u.bScrollbarLeft?"Left":"Right");n[u]=m?f+"px":"0px";J&&(p[0].style.width=s(b),l[0].style.width=s(b),l[0].style[u]=m?f+"px":"0px");o.scroll();if((a.bSorted||a.bFiltered)&&!a._drawHold)j.scrollTop=0}function F(a, +b,c){for(var e=0,d=0,f=b.length,h,g;d"));i.find("tfoot th, tfoot td").css("width","");var p=i.find("tbody tr"),j=pa(a,i.find("thead")[0]);for(k=0;k").css("width",s(a)).appendTo(b||P.body),e=c[0].offsetWidth;c.remove();return e}function Eb(a,b){var c=a.oScroll;if(c.sX||c.sY)c=!c.sX?c.iBarWidth:0,b.style.width=s(g(b).outerWidth()-c)}function Db(a,b){var c=Fb(a,b);if(0>c)return null; +var e=a.aoData[c];return!e.nTr?g("").html(v(a,c,b,"display"))[0]:e.anCells[b]}function Fb(a,b){for(var c,e=-1,d=-1,f=0,h=a.aoData.length;fe&&(e=c.length,d=f);return d}function s(a){return null===a?"0px":"number"==typeof a?0>a?"0px":a+"px":a.match(/\d$/)?a+"px":a}function Gb(){if(!p.__scrollbarWidth){var a=g("

      ").css({width:"100%",height:200,padding:0})[0],b=g("

      ").css({position:"absolute",top:0,left:0,width:200,height:150,padding:0, +overflow:"hidden",visibility:"hidden"}).append(a).appendTo("body"),c=a.offsetWidth;b.css("overflow","scroll");a=a.offsetWidth;c===a&&(a=b[0].clientWidth);b.remove();p.__scrollbarWidth=c-a}return p.__scrollbarWidth}function T(a){var b,c,e=[],d=a.aoColumns,f,h,i,j;b=a.aaSortingFixed;c=g.isPlainObject(b);var n=[];f=function(a){a.length&&!g.isArray(a[0])?n.push(a):n.push.apply(n,a)};g.isArray(b)&&f(b);c&&b.pre&&f(b.pre);f(a.aaSorting);c&&b.post&&f(b.post);for(a=0;ad?1:0,0!==c)return"asc"===g.dir?c:-c;c=e[a];d=e[b];return cd?1:0}):j.sort(function(a,b){var c,h,g,i,j=n.length,l=f[a]._aSortData,p=f[b]._aSortData;for(g=0;gh?1:0})}a.bSorted=!0}function Ib(a){for(var b,c,e=a.aoColumns,d=T(a),a=a.oLanguage.oAria,f=0,h=e.length;f/g,"");var j=c.nTh;j.removeAttribute("aria-sort");c.bSortable&&(0d?d+1:3));d=0;for(f=e.length;dd?d+1:3))}a.aLastSort=e}function Hb(a,b){var c=a.aoColumns[b],e=p.ext.order[c.sSortDataType],d;e&&(d=e.call(a.oInstance,a,b,$(a,b)));for(var f,h=p.ext.type.order[c.sType+"-pre"],g=0,j=a.aoData.length;g< +j;g++)if(c=a.aoData[g],c._aSortData||(c._aSortData=[]),!c._aSortData[b]||e)f=e?d[g]:v(a,g,b,"sort"),c._aSortData[b]=h?h(f):f}function xa(a){if(a.oFeatures.bStateSave&&!a.bDestroying){var b={time:+new Date,start:a._iDisplayStart,length:a._iDisplayLength,order:g.extend(!0,[],a.aaSorting),search:yb(a.oPreviousSearch),columns:g.map(a.aoColumns,function(b,e){return{visible:b.bVisible,search:yb(a.aoPreSearchCols[e])}})};u(a,"aoStateSaveParams","stateSaveParams",[a,b]);a.oSavedState=b;a.fnStateSaveCallback.call(a.oInstance, +a,b)}}function Jb(a){var b,c,e=a.aoColumns;if(a.oFeatures.bStateSave){var d=a.fnStateLoadCallback.call(a.oInstance,a);if(d&&d.time&&(b=u(a,"aoStateLoadParams","stateLoadParams",[a,d]),-1===g.inArray(!1,b)&&(b=a.iStateDuration,!(0=e.length?[0,c[1]]:c)});g.extend(a.oPreviousSearch, +zb(d.search));b=0;for(c=d.columns.length;b=c&&(b=c-e);b-=b%e;if(-1===e||0>b)b=0;a._iDisplayStart=b}function Oa(a,b){var c=a.renderer,e=p.ext.renderer[b];return g.isPlainObject(c)&& +c[b]?e[c[b]]||e._:"string"===typeof c?e[c]||e._:e._}function A(a){return a.oFeatures.bServerSide?"ssp":a.ajax||a.sAjaxSource?"ajax":"dom"}function Va(a,b){var c=[],c=Lb.numbers_length,e=Math.floor(c/2);b<=c?c=U(0,b):a<=e?(c=U(0,c-2),c.push("ellipsis"),c.push(b-1)):(a>=b-1-e?c=U(b-(c-2),b):(c=U(a-1,a+2),c.push("ellipsis"),c.push(b-1)),c.splice(0,0,"ellipsis"),c.splice(0,0,0));c.DT_el="span";return c}function cb(a){g.each({num:function(b){return za(b,a)},"num-fmt":function(b){return za(b,a,Wa)},"html-num":function(b){return za(b, +a,Aa)},"html-num-fmt":function(b){return za(b,a,Aa,Wa)}},function(b,c){w.type.order[b+a+"-pre"]=c;b.match(/^html\-/)&&(w.type.search[b+a]=w.type.search.html)})}function Mb(a){return function(){var b=[ya(this[p.ext.iApiIndex])].concat(Array.prototype.slice.call(arguments));return p.ext.internal[a].apply(this,b)}}var p,w,q,r,t,Xa={},Nb=/[\r\n]/g,Aa=/<.*?>/g,$b=/^[\w\+\-]/,ac=/[\w\+\-]$/,Xb=RegExp("(\\/|\\.|\\*|\\+|\\?|\\||\\(|\\)|\\[|\\]|\\{|\\}|\\\\|\\$|\\^|\\-)","g"),Wa=/[',$\u00a3\u20ac\u00a5%\u2009\u202F]/g, +H=function(a){return!a||!0===a||"-"===a?!0:!1},Ob=function(a){var b=parseInt(a,10);return!isNaN(b)&&isFinite(a)?b:null},Pb=function(a,b){Xa[b]||(Xa[b]=RegExp(ua(b),"g"));return"string"===typeof a&&"."!==b?a.replace(/\./g,"").replace(Xa[b],"."):a},Ya=function(a,b,c){var e="string"===typeof a;b&&e&&(a=Pb(a,b));c&&e&&(a=a.replace(Wa,""));return 
H(a)||!isNaN(parseFloat(a))&&isFinite(a)},Qb=function(a,b,c){return H(a)?!0:!(H(a)||"string"===typeof a)?null:Ya(a.replace(Aa,""),b,c)?!0:null},C=function(a, +b,c){var e=[],d=0,f=a.length;if(c!==l)for(;d")[0],Yb=va.textContent!==l,Zb=/<.*?>/g;p=function(a){this.$=function(a,b){return this.api(!0).$(a,b)};this._=function(a,b){return this.api(!0).rows(a,b).data()};this.api=function(a){return a?new q(ya(this[w.iApiIndex])):new q(this)};this.fnAddData=function(a,b){var c=this.api(!0),e=g.isArray(a)&&(g.isArray(a[0])||g.isPlainObject(a[0]))? +c.rows.add(a):c.row.add(a);(b===l||b)&&c.draw();return e.flatten().toArray()};this.fnAdjustColumnSizing=function(a){var b=this.api(!0).columns.adjust(),c=b.settings()[0],e=c.oScroll;a===l||a?b.draw(!1):(""!==e.sX||""!==e.sY)&&Y(c)};this.fnClearTable=function(a){var b=this.api(!0).clear();(a===l||a)&&b.draw()};this.fnClose=function(a){this.api(!0).row(a).child.hide()};this.fnDeleteRow=function(a,b,c){var e=this.api(!0),a=e.rows(a),d=a.settings()[0],g=d.aoData[a[0][0]];a.remove();b&&b.call(this,d,g); +(c===l||c)&&e.draw();return g};this.fnDestroy=function(a){this.api(!0).destroy(a)};this.fnDraw=function(a){this.api(!0).draw(!a)};this.fnFilter=function(a,b,c,e,d,g){d=this.api(!0);null===b||b===l?d.search(a,c,e,g):d.column(b).search(a,c,e,g);d.draw()};this.fnGetData=function(a,b){var c=this.api(!0);if(a!==l){var e=a.nodeName?a.nodeName.toLowerCase():"";return b!==l||"td"==e||"th"==e?c.cell(a,b).data():c.row(a).data()||null}return c.data().toArray()};this.fnGetNodes=function(a){var b=this.api(!0); +return a!==l?b.row(a).node():b.rows().nodes().flatten().toArray()};this.fnGetPosition=function(a){var b=this.api(!0),c=a.nodeName.toUpperCase();return"TR"==c?b.row(a).index():"TD"==c||"TH"==c?(a=b.cell(a).index(),[a.row,a.columnVisible,a.column]):null};this.fnIsOpen=function(a){return this.api(!0).row(a).child.isShown()};this.fnOpen=function(a,b,c){return this.api(!0).row(a).child(b,c).show().child()[0]};this.fnPageChange=function(a,b){var c=this.api(!0).page(a);(b===l||b)&&c.draw(!1)};this.fnSetColumnVis= +function(a,b,c){a=this.api(!0).column(a).visible(b);(c===l||c)&&a.columns.adjust().draw()};this.fnSettings=function(){return ya(this[w.iApiIndex])};this.fnSort=function(a){this.api(!0).order(a).draw()};this.fnSortListener=function(a,b,c){this.api(!0).order.listener(a,b,c)};this.fnUpdate=function(a,b,c,e,d){var g=this.api(!0);c===l||null===c?g.row(b).data(a):g.cell(b,c).data(a);(d===l||d)&&g.columns.adjust();(e===l||e)&&g.draw();return 0};this.fnVersionCheck=w.fnVersionCheck;var b=this,c=a===l,e=this.length; +c&&(a={});this.oApi=this.internal=w.internal;for(var d in p.ext.internal)d&&(this[d]=Mb(d));this.each(function(){var d={},d=1t<"F"ip>'),k.renderer)? 
+g.isPlainObject(k.renderer)&&!k.renderer.header&&(k.renderer.header="jqueryui"):k.renderer="jqueryui":g.extend(j,p.ext.classes,d.oClasses);g(this).addClass(j.sTable);if(""!==k.oScroll.sX||""!==k.oScroll.sY)k.oScroll.iBarWidth=Gb();!0===k.oScroll.sX&&(k.oScroll.sX="100%");k.iInitDisplayStart===l&&(k.iInitDisplayStart=d.iDisplayStart,k._iDisplayStart=d.iDisplayStart);null!==d.iDeferLoading&&(k.bDeferLoading=!0,h=g.isArray(d.iDeferLoading),k._iRecordsDisplay=h?d.iDeferLoading[0]:d.iDeferLoading,k._iRecordsTotal= +h?d.iDeferLoading[1]:d.iDeferLoading);var r=k.oLanguage;g.extend(!0,r,d.oLanguage);""!==r.sUrl&&(g.ajax({dataType:"json",url:r.sUrl,success:function(a){O(a);G(m.oLanguage,a);g.extend(true,r,a);ga(k)},error:function(){ga(k)}}),n=!0);null===d.asStripeClasses&&(k.asStripeClasses=[j.sStripeOdd,j.sStripeEven]);var h=k.asStripeClasses,q=g("tbody tr:eq(0)",this);-1!==g.inArray(!0,g.map(h,function(a){return q.hasClass(a)}))&&(g("tbody tr",this).removeClass(h.join(" ")),k.asDestroyStripes=h.slice());var o= +[],s,h=this.getElementsByTagName("thead");0!==h.length&&(da(k.aoHeader,h[0]),o=pa(k));if(null===d.aoColumns){s=[];h=0;for(i=o.length;h").appendTo(this));k.nTHead=i[0];i=g(this).children("tbody");0===i.length&&(i=g("").appendTo(this));k.nTBody=i[0];i=g(this).children("tfoot");if(0===i.length&&0").appendTo(this);0===i.length||0===i.children().length?g(this).addClass(j.sNoFooter): +0a?new q(b[a],this[a]):null},filter:function(a){var b=[];if(y.filter)b=y.filter.call(this,a,this);else for(var c=0,e=this.length;c").addClass(b);g("td",c).addClass(b).html(a)[0].colSpan=aa(e);d.push(c[0])}};if(g.isArray(a)||a instanceof g)for(var h=0,i=a.length;h=0?b:h.length+b];if(typeof a==="function"){var d=Ba(c,f);return g.map(h,function(b,f){return a(f,Vb(c,f,0,0,d),j[f])?f:null})}var k=typeof a==="string"?a.match(cc):"";if(k)switch(k[2]){case "visIdx":case "visible":b= +parseInt(k[1],10);if(b<0){var l=g.map(h,function(a,b){return a.bVisible?b:null});return[l[l.length+b]]}return[ka(c,b)];case "name":return g.map(i,function(a,b){return a===k[1]?b:null})}else return g(j).filter(a).map(function(){return g.inArray(this,j)}).toArray()})},1);c.selector.cols=a;c.selector.opts=b;return c});t("columns().header()","column().header()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].nTh},1)});t("columns().footer()","column().footer()",function(){return this.iterator("column", +function(a,b){return a.aoColumns[b].nTf},1)});t("columns().data()","column().data()",function(){return this.iterator("column-rows",Vb,1)});t("columns().dataSrc()","column().dataSrc()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].mData},1)});t("columns().cache()","column().cache()",function(a){return this.iterator("column-rows",function(b,c,e,d,f){return ha(b.aoData,f,"search"===a?"_aFilterData":"_aSortData",c)},1)});t("columns().nodes()","column().nodes()",function(){return this.iterator("column-rows", +function(a,b,c,e,d){return ha(a.aoData,d,"anCells",b)},1)});t("columns().visible()","column().visible()",function(a,b){return this.iterator("column",function(c,e){if(a===l)return c.aoColumns[e].bVisible;var d=c.aoColumns,f=d[e],h=c.aoData,i,j,n;if(a!==l&&f.bVisible!==a){if(a){var m=g.inArray(!0,C(d,"bVisible"),e+1);i=0;for(j=h.length;ie;return!0};p.isDataTable=p.fnIsDataTable=function(a){var b=g(a).get(0),c=!1;g.each(p.settings,function(a,d){if(d.nTable===b||d.nScrollHead===b||d.nScrollFoot===b)c=!0});return c};p.tables=p.fnTables=function(a){return 
g.map(p.settings,function(b){if(!a||a&&g(b.nTable).is(":visible"))return b.nTable})};p.util={throttle:ta,escapeRegex:ua}; +p.camelToHungarian=G;r("$()",function(a,b){var c=this.rows(b).nodes(),c=g(c);return g([].concat(c.filter(a).toArray(),c.find(a).toArray()))});g.each(["on","one","off"],function(a,b){r(b+"()",function(){var a=Array.prototype.slice.call(arguments);a[0].match(/\.dt\b/)||(a[0]+=".dt");var e=g(this.tables().nodes());e[b].apply(e,a);return this})});r("clear()",function(){return this.iterator("table",function(a){na(a)})});r("settings()",function(){return new q(this.context,this.context)});r("data()",function(){return this.iterator("table", +function(a){return C(a.aoData,"_aData")}).flatten()});r("destroy()",function(a){a=a||!1;return this.iterator("table",function(b){var c=b.nTableWrapper.parentNode,e=b.oClasses,d=b.nTable,f=b.nTBody,h=b.nTHead,i=b.nTFoot,j=g(d),f=g(f),l=g(b.nTableWrapper),m=g.map(b.aoData,function(a){return a.nTr}),o;b.bDestroying=!0;u(b,"aoDestroyCallback","destroy",[b]);a||(new q(b)).columns().visible(!0);l.unbind(".DT").find(":not(tbody *)").unbind(".DT");g(Da).unbind(".DT-"+b.sInstance);d!=h.parentNode&&(j.children("thead").detach(), +j.append(h));i&&d!=i.parentNode&&(j.children("tfoot").detach(),j.append(i));j.detach();l.detach();b.aaSorting=[];b.aaSortingFixed=[];wa(b);g(m).removeClass(b.asStripeClasses.join(" "));g("th, td",h).removeClass(e.sSortable+" "+e.sSortableAsc+" "+e.sSortableDesc+" "+e.sSortableNone);b.bJUI&&(g("th span."+e.sSortIcon+", td span."+e.sSortIcon,h).detach(),g("th, td",h).each(function(){var a=g("div."+e.sSortJUIWrapper,this);g(this).append(a.contents());a.detach()}));!a&&c&&c.insertBefore(d,b.nTableReinsertBefore); +f.children().detach();f.append(m);j.css("width",b.sDestroyWidth).removeClass(e.sTable);(o=b.asDestroyStripes.length)&&f.children().each(function(a){g(this).addClass(b.asDestroyStripes[a%o])});c=g.inArray(b,p.settings);-1!==c&&p.settings.splice(c,1)})});p.version="1.10.4";p.settings=[];p.models={};p.models.oSearch={bCaseInsensitive:!0,sSearch:"",bRegex:!1,bSmart:!0};p.models.oRow={nTr:null,anCells:null,_aData:[],_aSortData:null,_aFilterData:null,_sFilterRow:null,_sRowStripe:"",src:null};p.models.oColumn= +{idx:null,aDataSort:null,asSorting:null,bSearchable:null,bSortable:null,bVisible:null,_sManualType:null,_bAttrSrc:!1,fnCreatedCell:null,fnGetData:null,fnSetData:null,mData:null,mRender:null,nTh:null,nTf:null,sClass:null,sContentPadding:null,sDefaultContent:null,sName:null,sSortDataType:"std",sSortingClass:null,sSortingClassJUI:null,sTitle:null,sType:null,sWidth:null,sWidthOrig:null};p.defaults={aaData:null,aaSorting:[[0,"asc"]],aaSortingFixed:[],ajax:null,aLengthMenu:[10,25,50,100],aoColumns:null, +aoColumnDefs:null,aoSearchCols:[],asStripeClasses:null,bAutoWidth:!0,bDeferRender:!1,bDestroy:!1,bFilter:!0,bInfo:!0,bJQueryUI:!1,bLengthChange:!0,bPaginate:!0,bProcessing:!1,bRetrieve:!1,bScrollCollapse:!1,bServerSide:!1,bSort:!0,bSortMulti:!0,bSortCellsTop:!1,bSortClasses:!0,bStateSave:!1,fnCreatedRow:null,fnDrawCallback:null,fnFooterCallback:null,fnFormatNumber:function(a){return a.toString().replace(/\B(?=(\d{3})+(?!\d))/g,this.oLanguage.sThousands)},fnHeaderCallback:null,fnInfoCallback:null, +fnInitComplete:null,fnPreDrawCallback:null,fnRowCallback:null,fnServerData:null,fnServerParams:null,fnStateLoadCallback:function(a){try{return 
JSON.parse((-1===a.iStateDuration?sessionStorage:localStorage).getItem("DataTables_"+a.sInstance+"_"+location.pathname))}catch(b){}},fnStateLoadParams:null,fnStateLoaded:null,fnStateSaveCallback:function(a,b){try{(-1===a.iStateDuration?sessionStorage:localStorage).setItem("DataTables_"+a.sInstance+"_"+location.pathname,JSON.stringify(b))}catch(c){}},fnStateSaveParams:null, +iStateDuration:7200,iDeferLoading:null,iDisplayLength:10,iDisplayStart:0,iTabIndex:0,oClasses:{},oLanguage:{oAria:{sSortAscending:": activate to sort column ascending",sSortDescending:": activate to sort column descending"},oPaginate:{sFirst:"First",sLast:"Last",sNext:"Next",sPrevious:"Previous"},sEmptyTable:"No data available in table",sInfo:"Showing _START_ to _END_ of _TOTAL_ entries",sInfoEmpty:"Showing 0 to 0 of 0 entries",sInfoFiltered:"(filtered from _MAX_ total entries)",sInfoPostFix:"",sDecimal:"", +sThousands:",",sLengthMenu:"Show _MENU_ entries",sLoadingRecords:"Loading...",sProcessing:"Processing...",sSearch:"Search:",sSearchPlaceholder:"",sUrl:"",sZeroRecords:"No matching records found"},oSearch:g.extend({},p.models.oSearch),sAjaxDataProp:"data",sAjaxSource:null,sDom:"lfrtip",searchDelay:null,sPaginationType:"simple_numbers",sScrollX:"",sScrollXInner:"",sScrollY:"",sServerMethod:"GET",renderer:null};V(p.defaults);p.defaults.column={aDataSort:null,iDataSort:-1,asSorting:["asc","desc"],bSearchable:!0, +bSortable:!0,bVisible:!0,fnCreatedCell:null,mData:null,mRender:null,sCellType:"td",sClass:"",sContentPadding:"",sDefaultContent:null,sName:"",sSortDataType:"std",sTitle:null,sType:null,sWidth:null};V(p.defaults.column);p.models.oSettings={oFeatures:{bAutoWidth:null,bDeferRender:null,bFilter:null,bInfo:null,bLengthChange:null,bPaginate:null,bProcessing:null,bServerSide:null,bSort:null,bSortMulti:null,bSortClasses:null,bStateSave:null},oScroll:{bCollapse:null,iBarWidth:0,sX:null,sXInner:null,sY:null}, +oLanguage:{fnInfoCallback:null},oBrowser:{bScrollOversize:!1,bScrollbarLeft:!1},ajax:null,aanFeatures:[],aoData:[],aiDisplay:[],aiDisplayMaster:[],aoColumns:[],aoHeader:[],aoFooter:[],oPreviousSearch:{},aoPreSearchCols:[],aaSorting:null,aaSortingFixed:[],asStripeClasses:null,asDestroyStripes:[],sDestroyWidth:0,aoRowCallback:[],aoHeaderCallback:[],aoFooterCallback:[],aoDrawCallback:[],aoRowCreatedCallback:[],aoPreDrawCallback:[],aoInitComplete:[],aoStateSaveParams:[],aoStateLoadParams:[],aoStateLoaded:[], +sTableId:"",nTable:null,nTHead:null,nTFoot:null,nTBody:null,nTableWrapper:null,bDeferLoading:!1,bInitialised:!1,aoOpenRows:[],sDom:null,searchDelay:null,sPaginationType:"two_button",iStateDuration:0,aoStateSave:[],aoStateLoad:[],oSavedState:null,oLoadedState:null,sAjaxSource:null,sAjaxDataProp:null,bAjaxDataGet:!0,jqXHR:null,json:l,oAjaxData:l,fnServerData:null,aoServerParams:[],sServerMethod:null,fnFormatNumber:null,aLengthMenu:null,iDraw:0,bDrawing:!1,iDrawError:-1,_iDisplayLength:10,_iDisplayStart:0, +_iRecordsTotal:0,_iRecordsDisplay:0,bJUI:null,oClasses:{},bFiltered:!1,bSorted:!1,bSortCellsTop:null,oInit:null,aoDestroyCallback:[],fnRecordsTotal:function(){return"ssp"==A(this)?1*this._iRecordsTotal:this.aiDisplayMaster.length},fnRecordsDisplay:function(){return"ssp"==A(this)?1*this._iRecordsDisplay:this.aiDisplay.length},fnDisplayEnd:function(){var a=this._iDisplayLength,b=this._iDisplayStart,c=b+a,e=this.aiDisplay.length,d=this.oFeatures,f=d.bPaginate;return d.bServerSide?!1===f||-1===a?b+e: 
+Math.min(b+a,this._iRecordsDisplay):!f||c>e||-1===a?e:c},oInstance:null,sInstance:null,iTabIndex:0,nScrollHead:null,nScrollFoot:null,aLastSort:[],oPlugins:{}};p.ext=w={classes:{},errMode:"alert",feature:[],search:[],internal:{},legacy:{ajax:null},pager:{},renderer:{pageButton:{},header:{}},order:{},type:{detect:[],search:{},order:{}},_unique:0,fnVersionCheck:p.fnVersionCheck,iApiIndex:0,oJUIClasses:{},sVersion:p.version};g.extend(w,{afnFiltering:w.search,aTypes:w.type.detect,ofnSearch:w.type.search, +oSort:w.type.order,afnSortData:w.order,aoFeatures:w.feature,oApi:w.internal,oStdClasses:w.classes,oPagination:w.pager});g.extend(p.ext.classes,{sTable:"dataTable",sNoFooter:"no-footer",sPageButton:"paginate_button",sPageButtonActive:"current",sPageButtonDisabled:"disabled",sStripeOdd:"odd",sStripeEven:"even",sRowEmpty:"dataTables_empty",sWrapper:"dataTables_wrapper",sFilter:"dataTables_filter",sInfo:"dataTables_info",sPaging:"dataTables_paginate paging_",sLength:"dataTables_length",sProcessing:"dataTables_processing", +sSortAsc:"sorting_asc",sSortDesc:"sorting_desc",sSortable:"sorting",sSortableAsc:"sorting_asc_disabled",sSortableDesc:"sorting_desc_disabled",sSortableNone:"sorting_disabled",sSortColumn:"sorting_",sFilterInput:"",sLengthSelect:"",sScrollWrapper:"dataTables_scroll",sScrollHead:"dataTables_scrollHead",sScrollHeadInner:"dataTables_scrollHeadInner",sScrollBody:"dataTables_scrollBody",sScrollFoot:"dataTables_scrollFoot",sScrollFootInner:"dataTables_scrollFootInner",sHeaderTH:"",sFooterTH:"",sSortJUIAsc:"", +sSortJUIDesc:"",sSortJUI:"",sSortJUIAscAllowed:"",sSortJUIDescAllowed:"",sSortJUIWrapper:"",sSortIcon:"",sJUIHeader:"",sJUIFooter:""});var Ca="",Ca="",E=Ca+"ui-state-default",ia=Ca+"css_right ui-icon ui-icon-",Wb=Ca+"fg-toolbar ui-toolbar ui-widget-header ui-helper-clearfix";g.extend(p.ext.oJUIClasses,p.ext.classes,{sPageButton:"fg-button ui-button "+E,sPageButtonActive:"ui-state-disabled",sPageButtonDisabled:"ui-state-disabled",sPaging:"dataTables_paginate fg-buttonset ui-buttonset fg-buttonset-multi ui-buttonset-multi paging_", +sSortAsc:E+" sorting_asc",sSortDesc:E+" sorting_desc",sSortable:E+" sorting",sSortableAsc:E+" sorting_asc_disabled",sSortableDesc:E+" sorting_desc_disabled",sSortableNone:E+" sorting_disabled",sSortJUIAsc:ia+"triangle-1-n",sSortJUIDesc:ia+"triangle-1-s",sSortJUI:ia+"carat-2-n-s",sSortJUIAscAllowed:ia+"carat-1-n",sSortJUIDescAllowed:ia+"carat-1-s",sSortJUIWrapper:"DataTables_sort_wrapper",sSortIcon:"DataTables_sort_icon",sScrollHead:"dataTables_scrollHead "+E,sScrollFoot:"dataTables_scrollFoot "+E, +sHeaderTH:E,sFooterTH:E,sJUIHeader:Wb+" ui-corner-tl ui-corner-tr",sJUIFooter:Wb+" ui-corner-bl ui-corner-br"});var Lb=p.ext.pager;g.extend(Lb,{simple:function(){return["previous","next"]},full:function(){return["first","previous","next","last"]},simple_numbers:function(a,b){return["previous",Va(a,b),"next"]},full_numbers:function(a,b){return["first","previous",Va(a,b),"next","last"]},_numbers:Va,numbers_length:7});g.extend(!0,p.ext.renderer,{pageButton:{_:function(a,b,c,e,d,f){var h=a.oClasses,i= +a.oLanguage.oPaginate,j,l,m=0,o=function(b,e){var k,p,r,q,s=function(b){Sa(a,b.data.action,true)};k=0;for(p=e.length;k").appendTo(b);o(r,q)}else{l=j="";switch(q){case "ellipsis":b.append("");break;case "first":j=i.sFirst;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "previous":j=i.sPrevious;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "next":j=i.sNext;l=q+(d",{"class":h.sPageButton+" 
"+l,"aria-controls":a.sTableId,"data-dt-idx":m,tabindex:a.iTabIndex,id:c===0&&typeof q==="string"?a.sTableId+"_"+q:null}).html(j).appendTo(b);Ua(r,{action:q},s);m++}}}};try{var k=g(P.activeElement).data("dt-idx");o(g(b).empty(),e);k!==null&&g(b).find("[data-dt-idx="+k+"]").focus()}catch(p){}}}});g.extend(p.ext.type.detect,[function(a,b){var c=b.oLanguage.sDecimal; +return Ya(a,c)?"num"+c:null},function(a){if(a&&!(a instanceof Date)&&(!$b.test(a)||!ac.test(a)))return null;var b=Date.parse(a);return null!==b&&!isNaN(b)||H(a)?"date":null},function(a,b){var c=b.oLanguage.sDecimal;return Ya(a,c,!0)?"num-fmt"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c)?"html-num"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c,!0)?"html-num-fmt"+c:null},function(a){return H(a)||"string"===typeof a&&-1!==a.indexOf("<")?"html":null}]);g.extend(p.ext.type.search, +{html:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," ").replace(Aa,""):""},string:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," "):a}});var za=function(a,b,c,e){if(0!==a&&(!a||"-"===a))return-Infinity;b&&(a=Pb(a,b));a.replace&&(c&&(a=a.replace(c,"")),e&&(a=a.replace(e,"")));return 1*a};g.extend(w.type.order,{"date-pre":function(a){return Date.parse(a)||0},"html-pre":function(a){return H(a)?"":a.replace?a.replace(/<.*?>/g,"").toLowerCase():a+""},"string-pre":function(a){return H(a)? +"":"string"===typeof a?a.toLowerCase():!a.toString?"":a.toString()},"string-asc":function(a,b){return ab?1:0},"string-desc":function(a,b){return ab?-1:0}});cb("");g.extend(!0,p.ext.renderer,{header:{_:function(a,b,c,e){g(a.nTable).on("order.dt.DT",function(d,f,h,g){if(a===f){d=c.idx;b.removeClass(c.sSortingClass+" "+e.sSortAsc+" "+e.sSortDesc).addClass(g[d]=="asc"?e.sSortAsc:g[d]=="desc"?e.sSortDesc:c.sSortingClass)}})},jqueryui:function(a,b,c,e){g("
      ").addClass(e.sSortJUIWrapper).append(b.contents()).append(g("").addClass(e.sSortIcon+ +" "+c.sSortingClassJUI)).appendTo(b);g(a.nTable).on("order.dt.DT",function(d,f,g,i){if(a===f){d=c.idx;b.removeClass(e.sSortAsc+" "+e.sSortDesc).addClass(i[d]=="asc"?e.sSortAsc:i[d]=="desc"?e.sSortDesc:c.sSortingClass);b.find("span."+e.sSortIcon).removeClass(e.sSortJUIAsc+" "+e.sSortJUIDesc+" "+e.sSortJUI+" "+e.sSortJUIAscAllowed+" "+e.sSortJUIDescAllowed).addClass(i[d]=="asc"?e.sSortJUIAsc:i[d]=="desc"?e.sSortJUIDesc:c.sSortingClassJUI)}})}}});p.render={number:function(a,b,c,e){return{display:function(d){var f= +0>d?"-":"",d=Math.abs(parseFloat(d)),g=parseInt(d,10),d=c?b+(d-g).toFixed(c).substring(2):"";return f+(e||"")+g.toString().replace(/\B(?=(\d{3})+(?!\d))/g,a)+d}}}};g.extend(p.ext.internal,{_fnExternApiFunc:Mb,_fnBuildAjax:qa,_fnAjaxUpdate:jb,_fnAjaxParameters:sb,_fnAjaxUpdateDraw:tb,_fnAjaxDataSrc:ra,_fnAddColumn:Ea,_fnColumnOptions:ja,_fnAdjustColumnSizing:X,_fnVisibleToColumnIndex:ka,_fnColumnIndexToVisible:$,_fnVisbleColumns:aa,_fnGetColumns:Z,_fnColumnTypes:Ga,_fnApplyColumnDefs:hb,_fnHungarianMap:V, +_fnCamelToHungarian:G,_fnLanguageCompat:O,_fnBrowserDetect:fb,_fnAddData:I,_fnAddTr:la,_fnNodeToDataIndex:function(a,b){return b._DT_RowIndex!==l?b._DT_RowIndex:null},_fnNodeToColumnIndex:function(a,b,c){return g.inArray(c,a.aoData[b].anCells)},_fnGetCellData:v,_fnSetCellData:Ha,_fnSplitObjNotation:Ja,_fnGetObjectDataFn:W,_fnSetObjectDataFn:Q,_fnGetDataMaster:Ka,_fnClearTable:na,_fnDeleteIndex:oa,_fnInvalidate:ca,_fnGetRowElements:ma,_fnCreateTr:Ia,_fnBuildHead:ib,_fnDrawHead:ea,_fnDraw:L,_fnReDraw:M, +_fnAddOptionsHtml:lb,_fnDetectHeader:da,_fnGetUniqueThs:pa,_fnFeatureHtmlFilter:nb,_fnFilterComplete:fa,_fnFilterCustom:wb,_fnFilterColumn:vb,_fnFilter:ub,_fnFilterCreateSearch:Pa,_fnEscapeRegex:ua,_fnFilterData:xb,_fnFeatureHtmlInfo:qb,_fnUpdateInfo:Ab,_fnInfoMacros:Bb,_fnInitialise:ga,_fnInitComplete:sa,_fnLengthChange:Qa,_fnFeatureHtmlLength:mb,_fnFeatureHtmlPaginate:rb,_fnPageChange:Sa,_fnFeatureHtmlProcessing:ob,_fnProcessingDisplay:B,_fnFeatureHtmlTable:pb,_fnScrollDraw:Y,_fnApplyToChildren:F, +_fnCalculateColumnWidths:Fa,_fnThrottle:ta,_fnConvertToWidth:Cb,_fnScrollingWidthAdjust:Eb,_fnGetWidestNode:Db,_fnGetMaxLenString:Fb,_fnStringToCss:s,_fnScrollBarWidth:Gb,_fnSortFlatten:T,_fnSort:kb,_fnSortAria:Ib,_fnSortListener:Ta,_fnSortAttachListener:Na,_fnSortingClasses:wa,_fnSortData:Hb,_fnSaveState:xa,_fnLoadState:Jb,_fnSettingsFromNode:ya,_fnLog:R,_fnMap:D,_fnBindAction:Ua,_fnCallbackReg:x,_fnCallbackFire:u,_fnLengthOverflow:Ra,_fnRenderer:Oa,_fnDataSource:A,_fnRowAttributes:La,_fnCalculateEnd:function(){}}); +g.fn.dataTable=p;g.fn.dataTableSettings=p.settings;g.fn.dataTableExt=p.ext;g.fn.DataTable=function(a){return g(this).dataTable(a).api()};g.each(p,function(a,b){g.fn.DataTable[a]=b});return g.fn.dataTable};"function"===typeof define&&define.amd?define("datatables",["jquery"],O):"object"===typeof exports?O(require("jquery")):jQuery&&!jQuery.fn.dataTable&&O(jQuery)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js new file mode 100644 index 0000000000000..14925bf93d0f1 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js @@ -0,0 +1,592 @@ +/* +Shameless port of a shameless port +@defunkt => @janl => @aq + +See http://github.com/defunkt/mustache for more info. +*/ + +;(function($) { + +/*! 
+ * mustache.js - Logic-less {{mustache}} templates with JavaScript + * http://github.com/janl/mustache.js + */ + +/*global define: false*/ + +(function (root, factory) { + if (typeof exports === "object" && exports) { + factory(exports); // CommonJS + } else { + var mustache = {}; + factory(mustache); + if (typeof define === "function" && define.amd) { + define(mustache); // AMD + } else { + root.Mustache = mustache; // ++ + ++ + } else if (requestedIncomplete) {

      No incomplete applications found!

      } else {

      No completed applications found!

      ++
-          Did you specify the correct logging directory?
-          Please verify your setting of
-          spark.history.fs.logDirectory and whether you have the permissions to
-          access it. It is also possible that your application did not run to
-          completion or did not stop the SparkContext.
-
+          Did you specify the correct logging directory?
+          Please verify your setting of
+          spark.history.fs.logDirectory and whether you have the permissions to
+          access it. It is also possible that your application did not run to
+          completion or did not stop the SparkContext.
+

      } - } - - { + } + + + { if (requestedIncomplete) { "Back to completed applications" } else { "Show incomplete applications" } - } - -
      + } +
      +
      UIUtils.basicSparkPage(content, "History Server") } - private val appHeader = Seq( - "App ID", - "App Name", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private val appWithAttemptHeader = Seq( - "App ID", - "App Name", - "Attempt ID", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private def rangeIndices( - range: Seq[Int], - condition: Int => Boolean, - showIncomplete: Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => - {nextPage} ) - } - - private def attemptRow( - renderAttemptIdColumn: Boolean, - info: ApplicationHistoryInfo, - attempt: ApplicationAttemptInfo, - isFirst: Boolean): Seq[Node] = { - val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId)) - val startTime = UIUtils.formatDate(attempt.startTime) - val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" - val duration = - if (attempt.endTime > 0) { - UIUtils.formatDuration(attempt.endTime - attempt.startTime) - } else { - "-" - } - val lastUpdated = UIUtils.formatDate(attempt.lastUpdated) - - { - if (isFirst) { - if (info.attempts.size > 1 || renderAttemptIdColumn) { - - {info.id} - - {info.name} - } else { - {info.id} - {info.name} - } - } else { - Nil - } - } - { - if (renderAttemptIdColumn) { - if (info.attempts.size > 1 && attempt.attemptId.isDefined) { - {attempt.attemptId.get} - } else { -   - } - } else { - Nil - } - } - {startTime} - {endTime} - - {duration} - {attempt.sparkUser} - {lastUpdated} - - } - - private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(false, info, info.attempts.head, true) - } - - private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(true, info, info.attempts.head, true) ++ - info.attempts.drop(1).flatMap(attemptRow(true, info, _, false)) - } - - private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + Array( - "page=" + linkPage, - "showIncomplete=" + showIncomplete - ).mkString("&")) + private def makePageLink(showIncomplete: Boolean): String = { + UIUtils.prependBaseUri("/?" 
+ "showIncomplete=" + showIncomplete) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 0fc0fb59d861f..0f30183682469 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -71,6 +71,13 @@ private[spark] object ApplicationsListResource { attemptId = internalAttemptInfo.attemptId, startTime = new Date(internalAttemptInfo.startTime), endTime = new Date(internalAttemptInfo.endTime), + duration = + if (internalAttemptInfo.endTime > 0) { + internalAttemptInfo.endTime - internalAttemptInfo.startTime + } else { + 0 + }, + lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, completed = internalAttemptInfo.completed ) @@ -93,6 +100,13 @@ private[spark] object ApplicationsListResource { attemptId = None, startTime = new Date(internal.startTime), endTime = new Date(internal.endTime), + duration = + if (internal.endTime > 0) { + internal.endTime - internal.startTime + } else { + 0 + }, + lastUpdated = new Date(internal.endTime), sparkUser = internal.desc.user, completed = completed )) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 3adf5b1109af4..2b0079f5fd62e 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -35,6 +35,8 @@ class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, val endTime: Date, + val lastUpdated: Date, + val duration: Long, val sparkUser: String, val completed: Boolean = false) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cf45414c4f786..6cc30eeaf5d82 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -114,6 +114,8 @@ private[spark] class SparkUI private ( attemptId = None, startTime = new Date(startTime), endTime = new Date(-1), + duration = 0, + lastUpdated = new Date(startTime), sparkUser = "", completed = false )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1949c4b3cbf42..4ebee9093d41c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -157,11 +157,22 @@ private[spark] object UIUtils extends Logging { def commonHeaderNodes: Seq[Node] = { + + + + + + + + + diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index d575bf2f284b9..5bbb4ceb97228 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + 
"duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index d575bf2f284b9..5bbb4ceb97228 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git 
a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index 483632a3956ed..3f80a529a08b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 4b85690fd9199..508bdc17efe9f 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -13,7 +15,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 15c2de8ef99ea..5dca7d73de0cc 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -46,8 +56,10 @@ { "startTime": "2015-02-28T00:02:38.277GMT", "endTime": "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser": "irashid", "completed": true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index 07489ad96414a..cca32c791074f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 8f3d7160c723f..1ea1779e8369d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -5,13 +5,17 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 18659fc0c18de..be55b2e0fe1b7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -139,7 +139,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers code should be (HttpServletResponse.SC_OK) jsonOpt should be ('defined) errOpt should be (None) - val json = jsonOpt.get + val jsonOrg = jsonOpt.get + + // SPARK-10873 added the lastUpdated field for each application's attempt, + // the REST API returns the last modified time of EVENT LOG file for this field. + // It is not applicable to hard-code this dynamic field in a static expected file, + // so here we skip checking the lastUpdated field's value (setting it as ""). 
+ val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { + val subStrings = jsonOrg.split(",") + for (i <- subStrings.indices) { + if (subStrings(i).indexOf("lastUpdated") >= 0) { + subStrings(i) = "\"lastUpdated\":\"\"" + } + } + subStrings.mkString(",") + } else { + jsonOrg + } + val exp = IOUtils.toString(new FileInputStream( new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json"))) // compare the ASTs so formatting differences don't cause failures @@ -241,30 +258,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("generate history page with relative links") { - val historyServer = mock[HistoryServer] - val request = mock[HttpServletRequest] - val ui = mock[SparkUI] - val link = "/history/app1" - val info = new ApplicationHistoryInfo("app1", "app1", - List(ApplicationAttemptInfo(None, 0, 2, 1, "xxx", true))) - when(historyServer.getApplicationList()).thenReturn(Seq(info)) - when(ui.basePath).thenReturn(link) - when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) - val page = new HistoryPage(historyServer) - - // when - val response = page.render(request) - - // then - val links = response \\ "a" - val justHrefs = for { - l <- links - attrs <- l.attribute("href") - } yield (attrs.toString) - justHrefs should contain (UIUtils.prependBaseUri(resource = link)) - } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 968a2903f3010..a3ae4d2b730ff 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,6 +45,10 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), // SPARK-12600 Remove SQL deprecated methods ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), From c5f745ede01831b59c57effa7de88c648b82c13d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 10:24:23 -0800 Subject: [PATCH 017/107] [SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible. 
generated code comparison for `hash(int, double, string, array)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan Closes #10974 from cloud-fan/codegen. 
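A minimal standalone Scala sketch of the null-check elision described above (illustrative only, not the Spark codegen itself; the helper name `nullSafe` and the variable names `isNull3`/`isNull6`/`value1` are made up for the example): the hash-update fragment is emitted bare when a child expression is statically non-nullable, and wrapped in an `if (!isNull)` guard only when the child can produce nulls, which is what removes the `if (!false)` blocks from the "before" listing. The actual patch implements this with a `generateNullCheck` helper in `misc.scala` (see the diff below).

```
// Illustrative sketch: emit a guarded hash update only for nullable children.
object HashCodegenSketch {
  // Returns a Java source fragment; `hashUpdate` is the statement that folds one
  // child value into the running hash (e.g. "value1 = hashInt(value3, value1);").
  def nullSafe(nullable: Boolean, isNullVar: String)(hashUpdate: String): String =
    if (nullable) {
      s"""if (!$isNullVar) {
         |  $hashUpdate
         |}""".stripMargin
    } else {
      hashUpdate
    }

  def main(args: Array[String]): Unit = {
    // Non-nullable int child: no guard, matching the "after" output above.
    println(nullSafe(nullable = false, isNullVar = "isNull3")(
      "value1 = Murmur3_x86_32.hashInt(value3, value1);"))
    // Nullable string child: the guard is kept.
    println(nullSafe(nullable = true, isNullVar = "isNull6")(
      "value1 = Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, value1);"))
  }
}
```

Running the sketch prints one unguarded and one guarded fragment, mirroring the before/after comparison above.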
--- .../spark/sql/catalyst/expressions/misc.scala | 155 ++++++++---------- 1 file changed, 69 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 493e0aae01af7..8480c3f9a12f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" - val childrenHash = children.zipWithIndex.map { - case (child, dt) => - val childGen = child.gen(ctx) - val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) - s""" - ${childGen.code} - if (!${childGen.isNull}) { - ${childHash.code} - ${ev.value} = ${childHash.value}; - } - """ + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } }.mkString("\n") + s""" int ${ev.value} = $seed; $childrenHash """ } + private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + generateNullCheck(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + private def computeHash( input: String, dataType: DataType, - seed: String, - ctx: CodegenContext): ExprCode = { + result: String, + ctx: CodegenContext): String = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" dataType match { - case NullType => inlineValue(seed) + case NullType => "" case BooleanType => hashInt(s"$input ? 
1 : 0") case ByteType | ShortType | IntegerType | DateType => hashInt(input) case LongType | TimestampType => hashLong(input) @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hashLong(s"$input.toUnscaledLong()") } else { val bytes = ctx.freshName("bytes") - val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" - val offset = "Platform.BYTE_ARRAY_OFFSET" - val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - ExprCode(code, "false", result) + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ } case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" - val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" - inlineValue(monthsHash) - case BinaryType => - val offset = "Platform.BYTE_ARRAY_OFFSET" - inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) case StringType => val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" - inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, _) => - val result = ctx.freshName("result") + case ArrayType(et, containsNull) => val index = ctx.freshName("index") - val element = ctx.freshName("element") - val elementHash = computeHash(element, et, result, ctx) - val code = - s""" - int $result = $seed; - for (int $index = 0; $index < $input.numElements(); $index++) { - if (!$input.isNullAt($index)) { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${elementHash.code} - $result = ${elementHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ - case MapType(kt, vt, _) => - val result = ctx.freshName("result") + case MapType(kt, vt, valueContainsNull) => val index = ctx.freshName("index") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val key = ctx.freshName("key") - val value = ctx.freshName("value") - val keyHash = computeHash(key, kt, result, ctx) - val valueHash = computeHash(value, vt, result, ctx) - val code = - s""" - int $result = $seed; - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - $result = ${keyHash.value}; - if (!$values.isNullAt($index)) { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $result = ${valueHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ case StructType(fields) => - val result = ctx.freshName("result") - val fieldsHash = 
fields.map(_.dataType).zipWithIndex.map { - case (dt, index) => - val field = ctx.freshName("field") - val fieldHash = computeHash(field, dt, result, ctx) - s""" - if (!$input.isNullAt($index)) { - final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; - ${fieldHash.code} - $result = ${fieldHash.value}; - } - """ + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") - val code = - s""" - int $result = $seed; - $fieldsHash - """ - ExprCode(code, "false", result) - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } } From 5f686cc8b74ea9e36f56c31f14df90d134fd9343 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Jan 2016 11:22:12 -0800 Subject: [PATCH 018/107] [SPARK-12656] [SQL] Implement Intersect with Left-semi Join Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found one of the mainstream RDBMS did the same. In their query explain, Intersect is replaced by Left-semi Join. Left-semi Join could help outer-join elimination in Optimizer, as shown in the PR: https://github.com/apache/spark/pull/10566 Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10630 from gatorsmile/IntersectBySemiJoin. --- .../sql/catalyst/analysis/Analyzer.scala | 113 ++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++- .../sql/catalyst/optimizer/Optimizer.scala | 45 ++++--- .../plans/logical/basicOperators.scala | 32 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 5 + .../optimizer/AggregateOptimizeSuite.scala | 12 -- .../optimizer/ReplaceOperatorSuite.scala | 59 +++++++++ .../optimizer/SetOperationSuite.scala | 15 +-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 12 -- .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++ 11 files changed, 211 insertions(+), 122 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 33d76eeb21287..5fe700ee00673 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -344,6 +344,63 @@ class Analyzer( } } + /** + * Generate a new logical plan for the right child with different expression IDs + * for all conflicting attributes. + */ + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + + s"between $left and $right") + + right.collect { + // Handle base relations that might appear more than once. 
+ case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(_, windowExpressions, _, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) + } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + right + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + newRight + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -388,57 +445,11 @@ class Analyzer( .map(_.asInstanceOf[NamedExpression]) a.copy(aggregateExpressions = expanded) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - - right.collect { - // Handle base relations that might appear more than once. - case oldVersion: MultiInstanceRelation - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = oldVersion.newInstance() - (oldVersion, newVersion) - - // Handle projects that create conflicting aliases. 
- case oldVersion @ Project(projectList, _) - if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) - - case oldVersion @ Aggregate(_, aggregateExpressions, _) - if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - - case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => - val newOutput = oldVersion.generatorOutput.map(_.newInstance()) - (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - - case oldVersion @ Window(_, windowExpressions, _, _, child) - if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) - .nonEmpty => - (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - } - // Only handle first case, others will be fixed on the next pass. - .headOption match { - case None => - /* - * No result implies that there is a logical plan node that produces new references - * that this rule cannot handle. When that is the case, there must be another rule - * that resolves these conflicts. Otherwise, the analysis will fail. - */ - j - case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } - } - j.copy(right = newRight) - } + // To resolve duplicate expression IDs for Join and Intersect + case j @ Join(left, right, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) + case i @ Intersect(left, right) if !i.duplicateResolved => + i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on grandchild diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f2e78d97442e3..4a2f2b8bc6e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,9 +214,8 @@ trait CheckAnalysis { s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - // Special handling for cases when self-join introduce duplicate expression ids. 
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) + case j: Join if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) failAnalysis( s""" |Failure when resolving conflicting references in Join: @@ -224,6 +223,15 @@ trait CheckAnalysis { |Conflicting attributes: ${conflictingAttributes.mkString(",")} |""".stripMargin) + case i: Intersect if !i.duplicateResolved => + val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Intersect: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6addc2080648b..f156b5d10acc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, + ReplaceDistinctWithAggregate) :: Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down @@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] { } /** - * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * - * Intersect: - * It is not safe to pushdown Projections through it because we need to get the - * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * with deterministic condition. - * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters @@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Union or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. 
*/ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) - // Push down filter through INTERSECT - case Filter(condition, Intersect(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(left, right) - Filter(nondeterministic, - Intersect( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) - // Push down filter through EXCEPT case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) @@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. + * {{{ + * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2 + * }}} + * + * Note: + * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL. + * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated + * join conditions will be incorrect. + */ +object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right) => + assert(left.output.size == right.output.size) + val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e9c970cd08088..16f4b355b1b6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -103,15 +99,30 @@ private[sql] object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + + // Intersect are only resolved if they don't introduce ambiguous expression ids, + // since the Optimizer will convert Intersect to Join. + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } /** Factory for constructing new `Union` nodes. */ @@ -169,13 +180,13 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. 
override lazy val resolved: Boolean = { childrenResolved && expressions.forall(_.resolved) && - selfJoinResolved && + duplicateResolved && condition.forall(_.dataType == BooleanType) } } @@ -249,7 +260,7 @@ case class Range( end: Long, step: Long, numSlices: Int, - output: Seq[Attribute]) extends LeafNode { + output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { require(step != 0, "step cannot be 0") val numElements: BigInt = { val safeStart = BigInt(start) @@ -262,6 +273,9 @@ case class Range( } } + override def newInstance(): Range = + Range(start, end, step, numSlices, output.map(_.newInstance())) + override def statistics: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ab680282208c8..1938bce02a177 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("self intersect should resolve duplicate expression IDs") { + val plan = testRelation.intersect(testRelation) + assertAnalysisSuccess(plan) + } + test("SPARK-8654: invalid CAST in NULL IN(...) expression") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 37148a226f293..a4a12c0d62e92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } - test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala new file mode 100644 index 0000000000000..f8ae5d9be2084 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", FixedPoint(100), + ReplaceDistinctWithAggregate, + ReplaceIntersectWithSemiJoin) :: Nil + } + + test("replace Intersect with Left-semi Join") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = Intersect(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Distinct with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 2283f7c008ba2..b8ea32b4dfe01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) - val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { @@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest { comparePlans(combinedUnionsOptimized, unionOptimized3) } - test("intersect/except: filter to each side") { - val intersectQuery = testIntersect.where('b < 10) + test("except: filter to each side") { val exceptQuery = testExcept.where('c >= 5) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - val intersectCorrectAnswer = - Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } @@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest { } test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = 
testExcept.select('a, 'b, 'c) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - comparePlans(intersectOptimized, intersectQuery.analyze) comparePlans(exceptOptimized, exceptQuery.analyze) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 60fbb595e5758..9293e55141757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + case logical.Intersect(left, right) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil @@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil - case logical.Intersect(left, right) => - execution.Intersect(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4bf..fd81531c9316a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -420,18 +420,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } } -/** - * Returns the rows in left that also appear in right using the built in spark - * intersection function. - */ -case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = children.head.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) - } -} - /** * A plan node that does nothing but lie about the output of its child. 
Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 09bbe57a43ceb..4ff99bdf2937d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") { From 2b027e9a386fe4009f61ad03b169335af5a9a5c6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 29 Jan 2016 12:01:13 -0800 Subject: [PATCH 019/107] [SPARK-12818] Polishes spark-sketch module Fixes various minor code and Javadoc styling issues. Author: Cheng Lian Closes #10985 from liancheng/sketch-polishing. --- .../apache/spark/util/sketch/BitArray.java | 2 +- .../apache/spark/util/sketch/BloomFilter.java | 111 ++++++++++-------- .../spark/util/sketch/BloomFilterImpl.java | 40 ++++--- .../spark/util/sketch/CountMinSketch.java | 26 ++-- .../spark/util/sketch/CountMinSketchImpl.java | 12 ++ .../org/apache/spark/util/sketch/Utils.java | 2 +- 6 files changed, 110 insertions(+), 83 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 2a0484e324b13..480a0a79db32d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -22,7 +22,7 @@ import java.io.IOException; import java.util.Arrays; -public final class BitArray { +final class BitArray { private final long[] data; private long bitCount; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 81772fcea0ec2..c0b425e729595 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -22,16 +22,10 @@ import java.io.OutputStream; /** - * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether - * an element is a member of a set. It returns false when the element is definitely not in the - * set, returns true when the element is probably in the set. - * - * Internally a Bloom filter is initialized with 2 information: how many space to use(number of - * bits) and how many hash values to calculate for each record. To get as lower false positive - * probability as possible, user should call {@link BloomFilter#create} to automatically pick a - * best combination of these 2 parameters. 
- *
- * Currently the following data types are supported:
+ * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate
+ * containment test with one-sided error: if it claims that an item is contained in it, this
+ * might be in error, but if it claims that an item is not contained in it, then this is
+ * definitely true. Currently supported data types include:
  * <ul>
  *   <li>{@link Byte}</li>
  *   <li>{@link Short}</li>
@@ -39,14 +33,17 @@
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
  * </ul>
+ * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that
+ * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has
+ * not actually been put in the {@code BloomFilter}.
  *
- * The implementation is largely based on the {@code BloomFilter} class from guava.
+ * The implementation is largely based on the {@code BloomFilter} class from Guava.
  */
 public abstract class BloomFilter {
 
   public enum Version {
     /**
-     * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+     * {@code BloomFilter} binary format version 1. All values written in big-endian order:
      * <ul>
      *   <li>Version number, always 1 (32 bit)</li>
      *   <li>Number of hash functions (32 bit)</li>
      • @@ -68,14 +65,13 @@ int getVersionNumber() { } /** - * Returns the false positive probability, i.e. the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that - * has not actually been put in the {@code BloomFilter}. + * Returns the probability that {@linkplain #mightContain(Object)} erroneously return {@code true} + * for an object that has not actually been put in the {@code BloomFilter}. * - *

        Ideally, this number should be close to the {@code fpp} parameter - * passed in to create this bloom filter, or smaller. If it is - * significantly higher, it is usually the case that too many elements (more than - * expected) have been put in the {@code BloomFilter}, degenerating it. + * Ideally, this number should be close to the {@code fpp} parameter passed in + * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually + * the case that too many items (more than expected) have been put in the {@code BloomFilter}, + * degenerating it. */ public abstract double expectedFpp(); @@ -85,8 +81,8 @@ int getVersionNumber() { public abstract long bitSize(); /** - * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of - * {@link #mightContain(Object)} with the same element will always return {@code true}. + * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@linkplain #mightContain(Object)} with the same item will always return {@code true}. * * @return true if the bloom filter's bits changed as a result of this operation. If the bits * changed, this is definitely the first time {@code object} has been added to the @@ -98,19 +94,19 @@ int getVersionNumber() { public abstract boolean put(Object item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. + * A specialized variant of {@link #put(Object)} that only supports {@code String} items. */ - public abstract boolean putString(String str); + public abstract boolean putString(String item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put long. + * A specialized variant of {@link #put(Object)} that only supports {@code long} items. */ - public abstract boolean putLong(long l); + public abstract boolean putLong(long item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + * A specialized variant of {@link #put(Object)} that only supports byte array items. */ - public abstract boolean putBinary(byte[] bytes); + public abstract boolean putBinary(byte[] item); /** * Determines whether a given bloom filter is compatible with this bloom filter. For two @@ -137,38 +133,36 @@ int getVersionNumber() { public abstract boolean mightContain(Object item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 - * string. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items. */ - public abstract boolean mightContainString(String str); + public abstract boolean mightContainString(String item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items. */ - public abstract boolean mightContainLong(long l); + public abstract boolean mightContainLong(long item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte - * array. + * A specialized variant of {@link #mightContain(Object)} that only tests byte array items. */ - public abstract boolean mightContainBinary(byte[] bytes); + public abstract boolean mightContainBinary(byte[] item); /** - * Writes out this {@link BloomFilter} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link BloomFilter} to an output stream in binary format. 
It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link BloomFilter} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close + * the stream. */ public static BloomFilter readFrom(InputStream in) throws IOException { return BloomFilterImpl.readFrom(in); } /** - * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the + * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. * * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. @@ -197,21 +191,31 @@ private static long optimalNumOfBits(long n, double p) { static final double DEFAULT_FPP = 0.03; /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}. + * Creates a {@link BloomFilter} with the expected number of insertions and a default expected + * false positive probability of 3%. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems) { return create(expectedNumItems, DEFAULT_FPP); } /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick - * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter. + * Creates a {@link BloomFilter} with the expected number of insertions and expected false + * positive probability. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems, double fpp) { - assert fpp > 0.0 : "False positive probability must be > 0.0"; - assert fpp < 1.0 : "False positive probability must be < 1.0"; - long numBits = optimalNumOfBits(expectedNumItems, fpp); - return create(expectedNumItems, numBits); + if (fpp <= 0D || fpp >= 1D) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)" + ); + } + + return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp)); } /** @@ -219,9 +223,14 @@ public static BloomFilter create(long expectedNumItems, double fpp) { * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. 
*/ public static BloomFilter create(long expectedNumItems, long numBits) { - assert expectedNumItems > 0 : "Expected insertions must be > 0"; - assert numBits > 0 : "number of bits must be > 0"; - int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); - return new BloomFilterImpl(numHashFunctions, numBits); + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive"); + } + + if (numBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive"); + } + + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 35107e0b389d7..92c28bcb56a5a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,9 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter implements Serializable { +class BloomFilterImpl extends BloomFilter implements Serializable { private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { @@ -77,14 +78,14 @@ public boolean put(Object item) { } @Override - public boolean putString(String str) { - return putBinary(Utils.getBytesFromUTF8String(str)); + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean putBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean putBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -100,14 +101,14 @@ public boolean putBinary(byte[] bytes) { } @Override - public boolean mightContainString(String str) { - return mightContainBinary(Utils.getBytesFromUTF8String(str)); + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean mightContainBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean mightContainBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -124,14 +125,14 @@ public boolean mightContainBinary(byte[] bytes) { } @Override - public boolean putLong(long l) { + public boolean putLong(long item) { // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. // Note that `CountMinSketch` use a different strategy, it hash the input long element with // every i to produce n hash values. // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
- int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -147,9 +148,9 @@ public boolean putLong(long l) { } @Override - public boolean mightContainLong(long l) { - int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + public boolean mightContainLong(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -197,7 +198,7 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep throw new IncompatibleMergeException("Cannot merge null bloom filter"); } - if (!(other instanceof BloomFilter)) { + if (!(other instanceof BloomFilterImpl)) { throw new IncompatibleMergeException( "Cannot merge bloom filter of class " + other.getClass().getName() ); @@ -211,7 +212,8 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep if (this.numHashFunctions != that.numHashFunctions) { throw new IncompatibleMergeException( - "Cannot merge bloom filters with different number of hash functions"); + "Cannot merge bloom filters with different number of hash functions" + ); } this.bits.putAll(that.bits); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f0aac5bb00dfb..48f98680f48ca 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -22,7 +22,7 @@ import java.io.OutputStream; /** - * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in * sub-linear space. Currently, supported data types include: *

<ul>
 *   <li>{@link Byte}</li>
@@ -31,8 +31,7 @@
 *   <li>{@link Long}</li>
 *   <li>{@link String}</li>
 * </ul>
- * Each {@link CountMinSketch} is initialized with a random seed, and a pair
- * of parameters:
+ * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
 * <ol>
 *   <li>relative error (or {@code eps}), and</li>
 *   <li>confidence (or {@code delta})</li>
@@ -49,16 +48,13 @@
 *   <li>{@code w = ceil(-log(1 - confidence) / log(2))}</li>
 * </ol>
      * - * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details, - * including proofs of the estimates and error bounds used in this implementation. - * * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { public enum Version { /** - * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * {@code CountMinSketch} binary format version 1. All values written in big-endian order: *
<ul>
     *   <li>Version number, always 1 (32 bit)</li>
     *   <li>Total count of added items (64 bit)</li>
      • @@ -172,14 +168,14 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException; /** - * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link CountMinSketch} to an output stream in binary format. It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to + * close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); @@ -188,6 +184,10 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}. + * + * @param depth depth of the Count-min Sketch, must be positive + * @param width width of the Count-min Sketch, must be positive + * @param seed random seed */ public static CountMinSketch create(int depth, int width, int seed) { return new CountMinSketchImpl(depth, width, seed); @@ -196,6 +196,10 @@ public static CountMinSketch create(int depth, int width, int seed) { /** * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, * and random {@code seed}. + * + * @param eps relative error, must be positive + * @param confidence confidence, must be positive and less than 1.0 + * @param seed random seed */ public static CountMinSketch create(double eps, double confidence, int seed) { return new CountMinSketchImpl(eps, confidence, seed); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index c0631c6778df4..2acbb247b13cd 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -42,6 +42,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { + if (depth <= 0 || width <= 0) { + throw new IllegalArgumentException("Depth and width must be both positive"); + } + this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -50,6 +54,14 @@ private CountMinSketchImpl() {} } CountMinSketchImpl(double eps, double confidence, int seed) { + if (eps <= 0D) { + throw new IllegalArgumentException("Relative error must be positive"); + } + + if (confidence <= 0D || confidence >= 1D) { + throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)"); + } + // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java index a6b33313035b0..feb601d44f39d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -19,7 +19,7 @@ import java.io.UnsupportedEncodingException; -public class Utils { +class Utils { public static byte[] getBytesFromUTF8String(String str) { try { return 
str.getBytes("utf-8"); From e38b0baa38c6894335f187eaa4c8ea5c02d4563b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 13:45:03 -0800 Subject: [PATCH 020/107] [SPARK-13055] SQLHistoryListener throws ClassCastException This is an existing issue uncovered recently by #10835. The reason for the exception was because the `SQLHistoryListener` gets all sorts of accumulators, not just the ones that represent SQL metrics. For example, the listener gets the `internal.metrics.shuffleRead.remoteBlocksFetched`, which is an Int, then it proceeds to cast the Int to a Long, which fails. The fix is to mark accumulators representing SQL metrics using some internal metadata. Then we can identify which ones are SQL metrics and only process those in the `SQLHistoryListener`. Author: Andrew Or Closes #10971 from andrewor14/fix-sql-history. --- .../scala/org/apache/spark/Accumulable.scala | 8 ++++ .../apache/spark/executor/TaskMetrics.scala | 4 +- .../spark/scheduler/AccumulableInfo.scala | 5 ++- .../apache/spark/scheduler/DAGScheduler.scala | 7 +--- .../org/apache/spark/util/JsonProtocol.scala | 6 ++- .../spark/executor/TaskMetricsSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 14 ++----- .../spark/scheduler/TaskSetManagerSuite.scala | 8 +--- .../apache/spark/util/JsonProtocolSuite.scala | 16 +++++--- .../sql/execution/metric/SQLMetrics.scala | 21 +++++++++- .../spark/sql/execution/ui/SQLListener.scala | 23 +++++++---- .../execution/metric/SQLMetricsSuite.scala | 24 +++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 38 ++++++++++++++++++- 13 files changed, 133 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index 52f572b63fa95..601b503d12c7e 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -22,6 +22,7 @@ import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable import scala.reflect.ClassTag +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils @@ -187,6 +188,13 @@ class Accumulable[R, T] private ( */ private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } + /** + * Create an [[AccumulableInfo]] representation of this [[Accumulable]] with the provided values. + */ + private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal, countFailedValues) + } + // Called by Java when deserializing an object private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 8d10bf588ef1f..0a6ebcb3e0293 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -323,8 +323,8 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { * field is always empty, since this represents the partial updates recorded in this task, * not the aggregated value across multiple tasks. 
*/ - def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a => - new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues) + def accumulatorUpdates(): Seq[AccumulableInfo] = { + accums.map { a => a.toInfo(Some(a.localValue), None) } } // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 9d45fff9213c6..cedacad44afec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi * @param value total accumulated value so far, maybe None if used on executors to describe a task * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed + * @param metadata internal metadata associated with this accumulator, if any */ @DeveloperApi case class AccumulableInfo private[spark] ( @@ -43,7 +44,9 @@ case class AccumulableInfo private[spark] ( update: Option[Any], // represents a partial update within a task value: Option[Any], private[spark] val internal: Boolean, - private[spark] val countFailedValues: Boolean) + private[spark] val countFailedValues: Boolean, + // TODO: use this to identify internal task metrics instead of encoding it in the name + private[spark] val metadata: Option[String] = None) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 897479b50010d..ee0b8a1c95fd8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1101,11 +1101,8 @@ class DAGScheduler( acc ++= partialValue // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name - stage.latestInfo.accumulables(id) = new AccumulableInfo( - id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) - event.taskInfo.accumulables += new AccumulableInfo( - id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value)) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index dc8070cf8aad3..a2487eeb0483a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -290,7 +290,8 @@ private[spark] object JsonProtocol { ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ ("Internal" -> accumulableInfo.internal) ~ - ("Count Failed Values" -> accumulableInfo.countFailedValues) + ("Count Failed Values" -> accumulableInfo.countFailedValues) ~ + ("Metadata" -> accumulableInfo.metadata) } /** @@ -728,7 +729,8 @@ private[spark] object JsonProtocol { val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) val 
countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - new AccumulableInfo(id, name, update, value, internal, countFailedValues) + val metadata = (json \ "Metadata").extractOpt[String] + new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } /** diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 15be0b194ed8e..67c4595ed1923 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -551,8 +551,6 @@ private[spark] object TaskMetricsSuite extends Assertions { * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ - def makeInfo(a: Accumulable[_, _]): AccumulableInfo = { - new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues) - } + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d9c71ec2eae7b..62972a0738211 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1581,12 +1581,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(Accumulators.get(acc1.id).isDefined) assert(Accumulators.get(acc2.id).isDefined) assert(Accumulators.get(acc3.id).isDefined) - val accInfo1 = new AccumulableInfo( - acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false) - val accInfo2 = new AccumulableInfo( - acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false) - val accInfo3 = new AccumulableInfo( - acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false) + val accInfo1 = acc1.toInfo(Some(15L), None) + val accInfo2 = acc2.toInfo(Some(13L), None) + val accInfo3 = acc3.toInfo(Some(18L), None) val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) submit(new MyRDD(sc, 1, Nil), Array(0)) @@ -1954,10 +1951,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { val accumUpdates = reason match { - case Success => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues) - } + case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) } case ef: ExceptionFailure => ef.accumUpdates case _ => Seq.empty[AccumulableInfo] } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index a2e74365641a6..2c99dd5afb32e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -165,9 +165,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, 
MAX_TASK_FAILURES, clock) - val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have @@ -186,9 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + task.initialAccumulators.map { a => a.toInfo(Some(0L), None) } } // First three offers should all find tasks diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 57021d1d3d528..48951c3168032 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -374,15 +374,18 @@ class JsonProtocolSuite extends SparkFunSuite { test("AccumulableInfo backward compatibility") { // "Internal" property of AccumulableInfo was added in 1.5.1 val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) - val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Internal" }) + val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) assert(!oldInfo.internal) // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 - val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Count Failed Values" }) + val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" }) val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) assert(!oldInfo2.countFailedValues) + // "Metadata" property of AccumulableInfo was added in 2.0.0 + val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" }) + val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3) + assert(oldInfo3.metadata.isEmpty) } test("ExceptionFailure backward compatibility: accumulator updates") { @@ -820,9 +823,10 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeAccumulableInfo( id: Int, internal: Boolean = false, - countFailedValues: Boolean = false): AccumulableInfo = + countFailedValues: Boolean = false, + metadata: Option[String] = None): AccumulableInfo = new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), - internal, countFailedValues) + internal, countFailedValues, metadata) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 950dc7816241f..6b43d273fefde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.Utils /** @@ -27,9 +28,16 @@ import org.apache.spark.util.Utils * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( - name: String, val param: SQLMetricParam[R, T]) + name: String, + val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { + // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues, + Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + def reset(): Unit = { this.value = param.zero } @@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr // Although there is a boxing here, it's fine because it's only called in SQLListener override def value: Long = _value + + // Needed for SQLListenerSuite + override def equals(other: Any): Boolean = { + other match { + case o: LongSQLMetricValue => value == o.value + case _ => false + } + } } /** @@ -126,6 +142,9 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam( private[sql] object SQLMetrics { + // Identifier for distinguishing SQL metrics from other accumulators + private[sql] val ACCUM_IDENTIFIER = "sql" + private def createLongMetric( sc: SparkContext, name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 544606f1168b6..835e7ba6c5168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -23,7 +23,7 @@ import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric._ import org.apache.spark.ui.SparkUI @DeveloperApi @@ -314,14 +314,17 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } + +/** + * A [[SQLListener]] for rendering the SQL UI in the history server. 
+ */ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) extends SQLListener(conf) { private var sqlTabAttached = false - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing + override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { + // Do nothing; these events are not logged } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { @@ -329,9 +332,15 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { a => - val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L)) - a.copy(update = Some(newValue)) + taskEnd.taskInfo.accumulables.flatMap { a => + // Filter out accumulators that are not SQL metrics + // For now we assume all SQL metrics are Long's that have been JSON serialized as String's + if (a.metadata.exists(_ == SQLMetrics.ACCUM_IDENTIFIER)) { + val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L)) + Some(a.copy(update = Some(newValue))) + } else { + None + } }, finishTask = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c23..2260e4870299a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { @@ -356,6 +356,28 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("metrics can be loaded by history server") { + val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam) + metric += 10L + val metricInfo = metric.toInfo(Some(metric.localValue), None) + metricInfo.update match { + case Some(v: LongSQLMetricValue) => assert(v.value === 10L) + case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}") + case _ => fail("metric update is missing") + } + assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + // After serializing to JSON, the original value type is lost, but we can still + // identify that it's a SQL metric from the metadata + val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo) + val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson) + metricInfoDeser.update match { + case Some(v: String) => assert(v.toLong === 10L) + case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}") + case _ => fail("deserialized metric update is missing") + } + assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2c408c8878470..085e4a49a57e6 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.ui.SparkUI class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ @@ -335,8 +336,43 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } + test("SPARK-13055: history listener only tracks SQL metrics") { + val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) + // We need to post other events for the listener to track our accumulators. + // These are largely just boilerplate unrelated to what we're trying to test. + val df = createTestDataFrame + val executionStart = SparkListenerSQLExecutionStart( + 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) + val stageInfo = createStageInfo(0, 0) + val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) + val stageSubmitted = SparkListenerStageSubmitted(stageInfo) + // This task has both accumulators that are SQL metrics and accumulators that are not. + // The listener should only track the ones that are actually SQL metrics. + val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella") + val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball") + val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None) + val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None) + val taskInfo = createTaskInfo(0, 0) + taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo) + val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) + listener.onOtherEvent(executionStart) + listener.onJobStart(jobStart) + listener.onStageSubmitted(stageSubmitted) + // Before SPARK-13055, this throws ClassCastException because the history listener would + // assume that the accumulator value is of type Long, but this may not be true for + // accumulators that are not SQL metrics. + listener.onTaskEnd(taskEnd) + val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => + stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) + } + // Listener tracks only SQL metrics, not other accumulators + assert(trackedAccums.size === 1) + assert(trackedAccums.head === sqlMetricInfo) + } + } + class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { From 2cbc412821641cf9446c0621ffa1976bd7fc4fa1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 29 Jan 2016 16:57:34 -0800 Subject: [PATCH 021/107] [SPARK-13076][SQL] Rename ClientInterface -> HiveClient And ClientWrapper -> HiveClientImpl. I have some followup pull requests to introduce a new internal catalog, and I think this new naming reflects better the functionality of the two classes. Author: Reynold Xin Closes #10981 from rxin/SPARK-13076. 
--- .../apache/spark/sql/execution/commands.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 10 +++++----- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- ...ClientInterface.scala => HiveClient.scala} | 10 +++++----- ...ientWrapper.scala => HiveClientImpl.scala} | 20 +++++++++---------- .../spark/sql/hive/client/HiveShim.scala | 5 ++--- .../hive/client/IsolatedClientLoader.scala | 18 ++++++++--------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- .../apache/spark/sql/hive/test/TestHive.scala | 4 ++-- .../spark/sql/hive/client/VersionsSuite.scala | 4 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 12 files changed, 41 insertions(+), 42 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientInterface.scala => HiveClient.scala} (95%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientWrapper.scala => HiveClientImpl.scala} (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 703e4643cbd25..c6adb583f931b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -404,7 +404,7 @@ case class DescribeFunction( result } - case None => Seq(Row(s"Function: $functionName is not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e4..2b821c1056f56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -84,7 +84,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") } test("SPARK-6743: no columns from cache") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1797ea54f2501..05863ae18350d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,8 +79,8 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, listener: SQLListener, - @transient private val execHive: ClientWrapper, - @transient private val metaHive: ClientInterface, + @transient private val execHive: HiveClientImpl, + @transient private val metaHive: HiveClient, isRootContext: Boolean) extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { self => @@ -193,7 +193,7 @@ class HiveContext private[hive]( * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. 
*/ @transient - protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { + protected[hive] lazy val executionHive: HiveClientImpl = if (execHive != null) { execHive } else { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") @@ -203,7 +203,7 @@ class HiveContext private[hive]( config = newTemporaryConfiguration(useInMemoryDerby = true), isolationOn = false, baseClassLoader = Utils.getContextOrSparkClassLoader) - loader.createClient().asInstanceOf[ClientWrapper] + loader.createClient().asInstanceOf[HiveClientImpl] } /** @@ -222,7 +222,7 @@ class HiveContext private[hive]( * in the hive-site.xml file. */ @transient - protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { + protected[hive] lazy val metadataHive: HiveClient = if (metaHive != null) { metaHive } else { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 848aa4ec6fe56..61d0d6759ff72 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -96,7 +96,7 @@ private[hive] object HiveSerDe { } } -private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) +private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) extends Catalog with Logging { val conf = hive.conf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala similarity index 95% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 4eec3fef7408b..f681cc67041a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -60,9 +60,9 @@ private[hive] case class HiveTable( viewText: Option[String] = None) { @transient - private[client] var client: ClientInterface = _ + private[client] var client: HiveClient = _ - private[client] def withClient(ci: ClientInterface): this.type = { + private[client] def withClient(ci: HiveClient): this.type = { client = ci this } @@ -85,7 +85,7 @@ private[hive] case class HiveTable( * internal and external classloaders for a given version of Hive and thus must expose only * shared classes. */ -private[hive] trait ClientInterface { +private[hive] trait HiveClient { /** Returns the Hive Version of this client. 
*/ def version: HiveVersion @@ -184,8 +184,8 @@ private[hive] trait ClientInterface { /** Add a jar into class loader */ def addJar(path: String): Unit - /** Return a ClientInterface as new session, that will share the class loader and Hive client */ - def newSession(): ClientInterface + /** Return a [[HiveClient]] as new session, that will share the class loader and Hive client */ + def newSession(): HiveClient /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ def withHiveState[A](f: => A): A diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala similarity index 97% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 5307e924e7e55..cf1ff55c96fc9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.{CircularBuffer, Utils} * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, * allowing it to interact directly with a specific isolated version of Hive. Loading this class - * with the isolated classloader however will result in it only being visible as a ClientInterface, - * not a ClientWrapper. + * with the isolated classloader however will result in it only being visible as a [[HiveClient]], + * not a [[HiveClientImpl]]. * * This class needs to interact with multiple versions of Hive, but will always be compiled with * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility @@ -55,14 +55,14 @@ import org.apache.spark.util.{CircularBuffer, Utils} * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. * @param initClassLoader the classloader used when creating the `state` field of - * this ClientWrapper. + * this [[HiveClientImpl]]. */ -private[hive] class ClientWrapper( +private[hive] class HiveClientImpl( override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) - extends ClientInterface + extends HiveClient with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @@ -77,7 +77,7 @@ private[hive] class ClientWrapper( case hive.v1_2 => new Shim_v1_2() } - // Create an internal session state for this ClientWrapper. + // Create an internal session state for this HiveClientImpl. val state = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. 
@@ -160,7 +160,7 @@ private[hive] class ClientWrapper( case e: Exception if causedByThrift(e) => caughtException = e logWarning( - "HiveClientWrapper got thrift exception, destroying client and retrying " + + "HiveClient got thrift exception, destroying client and retrying " + s"(${retryLimit - numTries} tries remaining)", e) clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) @@ -199,7 +199,7 @@ private[hive] class ClientWrapper( */ def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this ClientWrapper. + // Set the thread local metastore client to the client associated with this HiveClientImpl. Hive.set(client) // The classloader in clientLoader could be changed after addJar, always use the latest // classloader @@ -521,8 +521,8 @@ private[hive] class ClientWrapper( runSqlHive(s"ADD JAR $path") } - def newSession(): ClientWrapper = { - clientLoader.createClient().asInstanceOf[ClientWrapper] + def newSession(): HiveClientImpl = { + clientLoader.createClient().asInstanceOf[HiveClientImpl] } def reset(): Unit = withHiveState { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index ca636b0265d41..70c10be25be9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegralType, StringType} /** - * A shim that defines the interface between ClientWrapper and the underlying Hive library used to - * talk to the metastore. Each Hive version has its own implementation of this class, defining + * A shim that defines the interface between [[HiveClientImpl]] and the underlying Hive library used + * to talk to the metastore. Each Hive version has its own implementation of this class, defining * version-specific version of needed functions. * * The guideline for writing shims is: @@ -52,7 +52,6 @@ private[client] sealed abstract class Shim { /** * Set the current SessionState to the given SessionState. Also, set the context classloader of * the current thread to the one set in the HiveConf of this given `state`. - * @param state */ def setCurrentSessionState(state: SessionState): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 010051d255fdc..dca7396ee1ab4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -124,15 +124,15 @@ private[hive] object IsolatedClientLoader extends Logging { } /** - * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * Creates a [[HiveClient]] using a classloader that works according to the following rules: * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` - * allowing the results of calls to the `ClientInterface` to be visible externally. + * allowing the results of calls to the [[HiveClient]] to be visible externally. * - Hive classes: new instances are loaded from `execJars`. These classes are not * accessible externally due to their custom loading. 
- * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * - [[HiveClientImpl]]: a new copy is created for each instance of `IsolatedClassLoader`. * This new instance is able to see a specific version of hive without using reflection. Since * this is a unique instance, it is not visible externally other than as a generic - * `ClientInterface`, unless `isolationOn` is set to `false`. + * [[HiveClient]], unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures * that are not compatible across versions. @@ -179,7 +179,7 @@ private[hive] class IsolatedClientLoader( /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[HiveClientImpl].getName) || name.startsWith(classOf[Shim].getName) || barrierPrefixes.exists(name.startsWith) @@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader( } /** The isolated client interface to Hive. */ - private[hive] def createClient(): ClientInterface = { + private[hive] def createClient(): HiveClient = { if (!isolationOn) { - return new ClientWrapper(version, config, baseClassLoader, this) + return new HiveClientImpl(version, config, baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -244,10 +244,10 @@ private[hive] class IsolatedClientLoader( try { classLoader - .loadClass(classOf[ClientWrapper].getName) + .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head .newInstance(version, config, classLoader, this) - .asInstanceOf[ClientInterface] + .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => if (e.getCause().isInstanceOf[NoClassDefFoundError]) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 56cab1aee89df..d5ed838ca4b1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -38,13 +38,13 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ private[hive] class HiveFunctionRegistry( underlying: analysis.FunctionRegistry, - executionHive: ClientWrapper) + executionHive: HiveClientImpl) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a33223af24370..246108e0d0e11 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.ClientWrapper +import 
org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -458,7 +458,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) } -private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) extends HiveFunctionRegistry(fr, client) { private val removedFunctions = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff10a251f3b45..1344a2cc4bd37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * A simple set of tests that call the methods of a [[HiveClient]], loading different version * of hive from maven central. These tests are simple in that they are mostly just testing to make * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. @@ -101,7 +101,7 @@ class VersionsSuite extends SparkFunSuite with Logging { private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") - private var client: ClientInterface = null + private var client: HiveClient = null versions.foreach { version => test(s"$version: create client") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 0d62d799c8dce..1ada2e325bda6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -199,7 +199,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") checkExistence(sql("describe functioN `~`"), true, "Function: ~", From e6ceac49a311faf3413acda57a6612fe806adf90 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 17:59:41 -0800 Subject: [PATCH 022/107] [SPARK-13096][TEST] Fix flaky verifyPeakExecutionMemorySet Previously we would assert things before all events are guaranteed to have been processed. To fix this, just block until all events are actually processed, i.e. until the listener queue is empty. https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/79/testReport/junit/org.apache.spark.util.collection/ExternalAppendOnlyMapSuite/spilling/ Author: Andrew Or Closes #10990 from andrewor14/accum-suite-less-flaky. 
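The same drain-the-listener-bus idea is reusable anywhere a test asserts on state built up by listeners. Below is a small Scala sketch of the pattern, assuming (as in AccumulatorSuite) code that lives in the org.apache.spark package, since listenerBus is private[spark]; the 10-second timeout simply mirrors the one added in the hunk that follows.

    package org.apache.spark

    object ListenerBusSync {
      // Run the test body, then block until every queued listener event has been
      // delivered, so assertions afterwards see fully updated listener state.
      def afterDrainingListenerBus[T](sc: SparkContext)(body: => T): T = {
        val result = body
        sc.listenerBus.waitUntilEmpty(10 * 1000)
        result
      }
    }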
--- core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 11c97d7d9a447..b8f2b96d7088d 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -307,6 +307,8 @@ private[spark] object AccumulatorSuite { val listener = new SaveInfoListener sc.addSparkListener(listener) testBody + // wait until all events have been processed before proceeding to assert things + sc.listenerBus.waitUntilEmpty(10 * 1000) val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) val isSet = accums.exists { a => a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) From 70e69fc4dd619654f5d24b8b84f6a94f7705c59b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 18:00:49 -0800 Subject: [PATCH 023/107] [SPARK-13088] Fix DAG viz in latest version of chrome Apparently chrome removed `SVGElement.prototype.getTransformToElement`, which is used by our JS library dagre-d3 when creating edges. The real diff can be found here: https://github.com/andrewor14/dagre-d3/commit/7d6c0002e4c74b82a02c5917876576f71e215590, which is taken from the fix in the main repo: https://github.com/cpettitt/dagre-d3/commit/1ef067f1c6ad2e0980f6f0ca471bce998784b7b2 Upstream issue: https://github.com/cpettitt/dagre-d3/issues/202 Author: Andrew Or Closes #10986 from andrewor14/fix-dag-viz. --- .../org/apache/spark/ui/static/dagre-d3.min.js | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index 2d9262b972a59..6fe8136c87ae0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -1,4 +1,5 @@ -/* This is a custom version of dagre-d3 on top of v0.4.3. 
The full list of commits can be found at http://github.com/andrewor14/dagre-d3/ */!function(e){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=e();else if("function"==typeof define&&define.amd)define([],e);else{var f;"undefined"!=typeof window?f=window:"undefined"!=typeof global?f=global:"undefined"!=typeof self&&(f=self),f.dagreD3=e()}}(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function 
applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else 
if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ -parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var 
graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var 
edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var 
children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var 
entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var 
k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdx0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof 
label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var 
weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){ +normalize.run(g)});time(" parentDummyChains",function(){parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" 
translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var 
edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var 
children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var 
sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var 
node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var 
parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var 
layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. 
"+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var 
From: Andrew Or
Date: Fri, 29 Jan 2016 18:03:04 -0800
Subject: [PATCH 024/107] [SPARK-13071] Coalescing HadoopRDD overwrites
 existing input metrics

This issue is causing tests to fail consistently in master with Hadoop
2.6 / 2.7. This is because for Hadoop 2.5+ we overwrite existing values
of `InputMetrics#bytesRead` in each call to `HadoopRDD#compute`. In the
case of coalesce, e.g.
```
sc.textFile(..., 4).coalesce(2).count()
```
we will call `compute` multiple times in the same task, overwriting
`bytesRead` values from previous calls to `compute`.

For a regression test, see `InputOutputMetricsSuite.input metrics for
old hadoop with coalesce`. I did not add a new regression test because
it's impossible without significant refactoring; there's a lot of
existing duplicate code in this corner of Spark.

This was caused by #10835.

Author: Andrew Or

Closes #10973 from andrewor14/fix-input-metrics-coalesce.
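The three diffs below all apply the same pattern, so it is worth stating once: snapshot `bytesRead` at the start of `compute`, then always report the snapshot plus the bytes read by the current partition instead of overwriting the metric. The following is a minimal, self-contained Scala sketch of that pattern only; `SimpleInputMetrics`, `readPartition`, and the callback are hypothetical stand-ins, not the actual Spark or Hadoop classes.

```scala
// A minimal sketch of the SPARK-13071 fix, using hypothetical stand-in
// classes: snapshot the metric before reading a partition, then always
// report snapshot + per-partition bytes instead of overwriting the metric.
object CoalesceMetricsSketch {

  // Stand-in for Spark's InputMetrics; not the real class.
  final class SimpleInputMetrics {
    private var _bytesRead: Long = 0L
    def bytesRead: Long = _bytesRead
    def setBytesRead(v: Long): Unit = { _bytesRead = v }
  }

  // Pretend to compute one partition, reporting bytes through a callback,
  // analogous to how compute() polls thread-local FileSystem statistics.
  def readPartition(partitionBytes: Long, inputMetrics: SimpleInputMetrics): Unit = {
    // Snapshot what earlier partitions in the same task already recorded.
    val existingBytesRead = inputMetrics.bytesRead

    var bytesSeenByCallback = 0L
    def getBytesRead(): Long = bytesSeenByCallback

    def updateBytesRead(): Unit =
      inputMetrics.setBytesRead(existingBytesRead + getBytesRead())

    bytesSeenByCallback = partitionBytes // simulate reading the split
    updateBytesRead()
  }

  def main(args: Array[String]): Unit = {
    val metrics = new SimpleInputMetrics
    // Two partitions computed in the same task, e.g. after coalesce(2).
    readPartition(100L, metrics)
    readPartition(250L, metrics)
    // Prints 350: without the snapshot, the second call would have
    // overwritten the first and reported only 250.
    println(metrics.bytesRead)
  }
}
```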
def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 4d2816e335fe3..e71d3405c0ead 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -130,6 +130,7 @@ class NewHadoopRDD[K, V]( val conf = getConf val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -139,9 +140,13 @@ class NewHadoopRDD[K, V]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index edd87c2d8ed07..9703b16c86f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -127,6 +127,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( val conf = getConf(isDriverSide = false) val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { @@ -142,9 +143,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } From e6a02c66d53f59ba2d5c1548494ae80a385f9f5c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 20:16:11 -0800 Subject: [PATCH 025/107] [SPARK-12914] [SQL] generate aggregation with grouping keys This PR adds support for grouping keys in the generated TungstenAggregate. Spilling and performance improvements for BytesToBytesMap will be done in a follow-up PR. Author: Davies Liu Closes #10855 from davies/gen_keys.
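As a concrete illustration, here is a hedged sketch of the kind of query that now runs through whole-stage codegen; it mirrors the benchmark and regression test added later in this patch and assumes a spark-shell style session on this branch:
```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Illustrative only: a grouped aggregation whose grouping keys are hashed into an
// UnsafeFixedWidthAggregationMap by the generated code when the flag below is enabled.
val sc = SparkContext.getOrCreate(new SparkConf().setMaster("local[1]").setAppName("agg-keys-demo"))
val sqlContext = SQLContext.getOrCreate(sc)
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(1 << 20).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
```
The `doProduceWithKeys`/`doConsumeWithKeys` split introduced below is what generates the hash-map build and output loop for such a plan.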
--- .../expressions/codegen/CodeGenerator.scala | 47 ++++ .../codegen/GenerateMutableProjection.scala | 27 +- .../sql/execution/BufferedRowIterator.java | 6 +- .../aggregate/TungstenAggregate.scala | 238 ++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 119 ++++++++- .../execution/WholeStageCodegenSuite.scala | 9 + 6 files changed, 393 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f7..21f9198073d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,20 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -198,6 +212,39 @@ class CodegenContext { } } + /** + * Update a column in MutableRow from ExprCode. + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (dataType.isInstanceOf[DecimalType]) { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + ${setColumn(row, dataType, ordinal, "null")}; + } + """ + } else { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + $row.setNullAt($ordinal); + } + """ + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + /** * Returns the name used in accessor and setter for a Java primitive type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b8..5b4dc8df8622b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val updates = validExpr.zip(index).map { case (e, i) => - if (e.nullable) { - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } else { - s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } - } else { - s""" - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - """ - } - + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index b1bbb1da10a39..6acf70dbbad0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import scala.collection.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -34,7 +36,7 @@ public class BufferedRowIterator { // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); - public boolean hasNext() { + public boolean hasNext() throws IOException { if (currentRow == null) { processNext(); } @@ -56,7 +58,7 @@ public void setInput(Iterator iter) { * * After it's called, if currentRow is still null, it means no more rows left. 
*/ - protected void processNext() { + protected void processNext() throws IOException { if (input.hasNext()) { currentRow = input.next(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index ff2f38bfd9105..57db7262fdaf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -114,22 +116,38 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + override def supportCodegen: Boolean = { - groupingExpressions.isEmpty && - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - // The variables used as aggregation buffer - private var bufVars: Seq[ExprCode] = _ - - private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { child.asInstanceOf[CodegenSupport].upstream() } protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -176,10 +194,10 @@ case class TungstenAggregate( (resultVars, resultVars.map(_.code).mkString("\n")) } - val doAgg = ctx.freshName("doAgg") + val doAgg = ctx.freshName("doAggregateWithoutKey") ctx.addNewFunction(doAgg, s""" - | private void $doAgg() { + | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | @@ -200,7 +218,7 @@ case class TungstenAggregate( """.stripMargin } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // 
only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output @@ -212,7 +230,6 @@ case class TungstenAggregate( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) @@ -232,6 +249,199 @@ case class TungstenAggregate( """.stripMargin } + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + private val bufferSchema = StructType.fromAttributes(bufferAttributes) + + // The name for HashMap + private var hashMapTerm: String = _ + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + + /** + * Update peak execution memory, called in generated Java class. 
+ */ + def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val metrics = TaskContext.get().taskMetrics() + metrics.incPeakExecutionMemory(mapMemory) + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + ctx.INPUT_ROW = bufferTerm + val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttributes).gen(ctx) + } + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + ${keyVars.map(_.code).mkString("\n")} + ${bufferVars.map(_.code).mkString("\n")} + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } + s""" + ${eval.map(_.code).mkString("\n")} + ${consume(ctx, eval)} + """ + } + + val doAgg = ctx.freshName("doAggregateWithKeys") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + $iterTerm = $hashMapTerm.iterator(); + } + """) + + s""" + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + while ($iterTerm.next()) { + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + } + + $thisPlan.updatePeakMemory($hashMapTerm); + $hashMapTerm.free(); + """ + } + + private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // 
create grouping key + ctx.currentVars = input + val keyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val key = keyCode.value + val buffer = ctx.freshName("aggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + val inputAttr = bufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input + ctx.INPUT_ROW = buffer + // TODO: support subexpression elimination + val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updates = evals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) + } + + s""" + // generate grouping key + ${keyCode.code} + UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } + + // evaluate aggregate function + ${evals.map(_.code).mkString("\n")} + // update aggregate buffer + ${updates.mkString("\n")} + """ + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index c4aad398bfa54..2f09c8a114bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** @@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" */ class BenchmarkWholeStageCodegen extends SparkFunSuite { - def testWholeStage(values: Int): Unit = { - val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - val sc = SparkContext.getOrCreate(conf) - val sqlContext = SQLContext.getOrCreate(sc) + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) - val benchmark = new Benchmark("Single Int Column Scan", values) + def testWholeStage(values: Int): Unit = { + val benchmark = new Benchmark("rang/filter/aggregate", values) - benchmark.addCase("Without whole stage codegen") { iter => + benchmark.addCase("Without codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "false") sqlContext.range(values).filter("(id & 1) = 1").count() } - benchmark.addCase("With whole stage codegen") { iter => + benchmark.addCase("With codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "true") 
sqlContext.range(values).filter("(id & 1) = 1").count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - Without whole stage codegen 7775.53 26.97 1.00 X - With whole stage codegen 342.15 612.94 22.73 X + Without codegen 7775.53 26.97 1.00 X + With codegen 342.15 612.94 22.73 X */ benchmark.run() } - ignore("benchmark") { - testWholeStage(1024 * 1024 * 200) + def testAggregateWithKey(values: Int): Unit = { + val benchmark = new Benchmark("Aggregate with keys", values) + + benchmark.addCase("Aggregate w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + benchmark.addCase(s"Aggregate w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 4254.38 4.93 1.00 X + Aggregate w codegen 2661.45 7.88 1.60 X + */ + benchmark.run() + } + + def testBytesToBytesMap(values: Int): Unit = { + val benchmark = new Benchmark("BytesToBytesMap", values) + + benchmark.addCase("hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < values) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + s += h + i += 1 + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + while (i < values) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + if (loc.isDefined) { + value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, + loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { + loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + } + } + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + hash 662.06 79.19 1.00 X + BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X + BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + */ + benchmark.run() + } + + test("benchmark") { + // 
testWholeStage(1024 * 1024 * 200) + // testAggregateWithKey(20 << 20) + // testBytesToBytesMap(1024 * 1024 * 50) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 300788c88ab2f..c2516509dfbbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } } From dab246f7e4664d36073ec49d9df8a11c5e998cdb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 23:37:51 -0800 Subject: [PATCH 026/107] [SPARK-13098] [SQL] remove GenericInternalRowWithSchema This class is only used for serialization of Python DataFrame. However, we don't require internal row there, so `GenericRowWithSchema` can also do the job. Author: Wenchen Fan Closes #10992 from cloud-fan/python. --- .../spark/sql/catalyst/expressions/rows.scala | 12 ------------ .../org/apache/spark/sql/execution/python.scala | 13 +++++-------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 387d979484f2c..be6b2530ef39e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -233,18 +233,6 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override def copy(): GenericInternalRow = this } -/** - * This is used for serialization of Python DataFrame - */ -class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) - extends GenericInternalRow(values) { - - /** No-arg constructor for serialization. */ - protected def this() = this(null, null) - - def fieldIndex(name: String): Int = schema.fieldIndex(name) -} - class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index e3a016e18db87..bf62bb05c3d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -143,7 +143,7 @@ object EvaluatePython { values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) i += 1 } - new GenericInternalRowWithSchema(values, struct) + new GenericRowWithSchema(values, struct) case (a: ArrayData, array: ArrayType) => val values = new java.util.ArrayList[Any](a.numElements()) @@ -199,10 +199,7 @@ object EvaluatePython { case (c: Long, TimestampType) => c - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) + case (c, StringType) => UTF8String.fromString(c.toString) case (c: String, BinaryType) => c.getBytes("utf-8") case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c @@ -263,11 +260,11 @@ object EvaluatePython { } /** - * Pickler for InternalRow + * Pickler for external row. */ private class RowPickler extends IObjectPickler { - private val cls = classOf[GenericInternalRowWithSchema] + private val cls = classOf[GenericRowWithSchema] // register this to Pickler and Unpickler def register(): Unit = { @@ -282,7 +279,7 @@ object EvaluatePython { } else { // it will be memorized by Pickler to save some bytes pickler.save(this) - val row = obj.asInstanceOf[GenericInternalRowWithSchema] + val row = obj.asInstanceOf[GenericRowWithSchema] // schema should always be same object for memoization pickler.save(row.schema) out.write(Opcodes.TUPLE1) From 289373b28cd2546165187de2e6a9185a1257b1e7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 30 Jan 2016 00:20:28 -0800 Subject: [PATCH 027/107] [SPARK-6363][BUILD] Make Scala 2.11 the default Scala version This patch changes Spark's build to make Scala 2.11 the default Scala version. To be clear, this does not mean that Spark will stop supporting Scala 2.10: users will still be able to compile Spark for Scala 2.10 by following the instructions on the "Building Spark" page; however, it does mean that Scala 2.11 will be the default Scala version used by our CI builds (including pull request builds). The Scala 2.11 compiler is faster than 2.10, so I think we'll be able to look forward to a slight speedup in our CI builds (it looks like it's about 2X faster for the Maven compile-only builds, for instance). After this patch is merged, I'll update Jenkins to add new compile-only jobs to ensure that Scala 2.10 compilation doesn't break. Author: Josh Rosen Closes #10608 from JoshRosen/SPARK-6363. 
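For the sbt side of the flip (the `project/SparkBuild.scala` change further down), here is a standalone, hedged sketch of the default-selection logic; the object name is illustrative and the version strings are simply this snapshot's values, not authoritative:
```
// Simplified sketch of the default-version flip: a bare -Dscala-2.10 selects the 2.10
// profile; otherwise the build defaults to Scala 2.11.
object ScalaVersionDefault {
  def resolve(): String = {
    // Maven treats a bare -Dname as -Dname=true; mirror that so sbt and Maven agree.
    if (System.getProperty("scala-2.10") == "") {
      System.setProperty("scala-2.10", "true")
    }
    if (System.getProperty("scala-2.10") == "true") "2.10.5" else "2.11.7"
  }

  def main(args: Array[String]): Unit =
    println(s"Default Scala version for this build: ${resolve()}")
}
```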
--- assembly/pom.xml | 4 +-- common/sketch/pom.xml | 4 +-- core/pom.xml | 4 +-- dev/create-release/release-build.sh | 14 ++++----- dev/deps/spark-deps-hadoop-2.2 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.3 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.4 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.6 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.7 | 31 +++++++++---------- docker-integration-tests/pom.xml | 4 +-- docs/_plugins/copy_api_dirs.rb | 2 +- docs/building-spark.md | 10 +++--- examples/pom.xml | 4 +-- external/akka/pom.xml | 4 +-- external/flume-assembly/pom.xml | 4 +-- external/flume-sink/pom.xml | 4 +-- external/flume/pom.xml | 4 +-- external/kafka-assembly/pom.xml | 4 +-- external/kafka/pom.xml | 4 +-- external/mqtt-assembly/pom.xml | 4 +-- external/mqtt/pom.xml | 4 +-- external/twitter/pom.xml | 4 +-- external/zeromq/pom.xml | 4 +-- extras/java8-tests/pom.xml | 4 +-- extras/kinesis-asl-assembly/pom.xml | 4 +-- extras/kinesis-asl/pom.xml | 4 +-- extras/spark-ganglia-lgpl/pom.xml | 4 +-- graphx/pom.xml | 4 +-- launcher/pom.xml | 4 +-- mllib/pom.xml | 4 +-- network/common/pom.xml | 4 +-- network/shuffle/pom.xml | 4 +-- network/yarn/pom.xml | 4 +-- pom.xml | 8 ++--- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 6 ++++ project/SparkBuild.scala | 12 +++---- repl/pom.xml | 8 ++--- .../scala/org/apache/spark/repl/Main.scala | 9 +++++- .../org/apache/spark/repl/ReplSuite.scala | 7 +---- sql/catalyst/pom.xml | 13 ++------ sql/core/pom.xml | 6 ++-- sql/hive-thriftserver/pom.xml | 4 +-- sql/hive/pom.xml | 4 +-- streaming/pom.xml | 4 +-- tags/pom.xml | 4 +-- tools/pom.xml | 4 +-- unsafe/pom.xml | 4 +-- yarn/pom.xml | 4 +-- 49 files changed, 186 insertions(+), 194 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 6c79f9189787d..477d4931c3a88 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-assembly_2.10 + spark-assembly_2.11 Spark Project Assembly http://spark.apache.org/ pom diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2cafe8c548f5f..442043cb51164 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 jar Spark Project Sketch http://spark.apache.org/ diff --git a/core/pom.xml b/core/pom.xml index 0ab170e028ab4..be40d9936afd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-core_2.10 + spark-core_2.11 core diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 00bf81120df65..2fd7fcc39ea28 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -134,9 +134,9 @@ if [[ "$1" == "package" ]]; then cd spark-$SPARK_VERSION-bin-$NAME - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 + # TODO There should probably be a flag to make-distribution to allow 2.10 support + if [[ $FLAGS == *scala-2.10* ]]; then + ./dev/change-scala-version.sh 2.10 fi export ZINC_PORT=$ZINC_PORT @@ -228,8 +228,8 @@ if [[ "$1" == "publish-snapshot" ]]; then $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests 
$PUBLISH_PROFILES \ -Phive-thriftserver deploy - ./dev/change-scala-version.sh 2.11 - $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \ -DskipTests $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process @@ -266,9 +266,9 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver clean install - ./dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.10 - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ -DskipTests $PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 4d9937c5cbc34..3a14499d9b4d9 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -14,13 +14,13 @@ avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -86,10 +86,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar javax.servlet-3.1.jar @@ -111,15 +110,14 @@ jets3t-0.7.1.jar jettison-1.1.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -158,19 +156,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index fd659ee20df1a..615836b3d3b77 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar 
commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -102,15 +101,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -149,19 +147,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index afae3deb9ada2..f275226f1d088 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -103,15 +102,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -150,19 +148,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar 
-spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 5a6460136a3a0..21432a16e3659 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -156,19 +154,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 70083e7f3d16a..20e09cd002635 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar 
-json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -157,19 +155,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 78b638ecfa638..833ca29cd8218 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml - spark-docker-integration-tests_2.10 + spark-docker-integration-tests_2.11 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 174c202e37918..f926d67e6beaf 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.10/unidoc" + source = "../target/scala-2.11/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index e1abcf1be501d..975e1b295c8ae 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -114,13 +114,11 @@ By default Spark will build with Hive 0.13.1 bindings. mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package {% endhighlight %} -# Building for Scala 2.11 -To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: +# Building for Scala 2.10 +To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: - ./dev/change-scala-version.sh 2.11 - mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package - -Spark does not yet support its JDBC component for Scala 2.11. 
+ ./dev/change-scala-version.sh 2.10 + mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package # Spark Tests in Maven diff --git a/examples/pom.xml b/examples/pom.xml index 9437cee2abfdf..82baa9085b4f9 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-examples_2.10 + spark-examples_2.11 examples diff --git a/external/akka/pom.xml b/external/akka/pom.xml index 06c8e8aaabd8c..bbe644e3b32b3 100644 --- a/external/akka/pom.xml +++ b/external/akka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-akka_2.10 + spark-streaming-akka_2.11 streaming-akka diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index b2c377fe4cc9b..ac15b93c048da 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-assembly_2.10 + spark-streaming-flume-assembly_2.11 jar Spark Project External Flume Assembly http://spark.apache.org/ diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 4b6485ee0a71a..e4effe158c826 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-sink_2.10 + spark-streaming-flume-sink_2.11 streaming-flume-sink diff --git a/external/flume/pom.xml b/external/flume/pom.xml index a79656c6f7d96..d650dd034d636 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume_2.10 + spark-streaming-flume_2.11 streaming-flume diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0c466b3c4ac37..62818f5e8f434 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka-assembly_2.10 + spark-streaming-kafka-assembly_2.11 jar Spark Project External Kafka Assembly http://spark.apache.org/ diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 5180ab6dbafbd..68d52e9339b3d 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka_2.10 + spark-streaming-kafka_2.11 streaming-kafka diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index c4a1ae26ea699..ac2a3f65ed2f5 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt-assembly_2.10 + spark-streaming-mqtt-assembly_2.11 jar Spark Project External MQTT Assembly http://spark.apache.org/ diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index d3a2bf5825b08..d0d968782c7f1 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + 
spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt_2.10 + spark-streaming-mqtt_2.11 streaming-mqtt diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 7b628b09ea6a5..5d4053afcbba7 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-twitter_2.10 + spark-streaming-twitter_2.11 streaming-twitter diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7781aaeed9e0c..f16bc0f319744 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-zeromq_2.10 + spark-streaming-zeromq_2.11 streaming-zeromq diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4dfe3b654df1a..0ad9c5303a36a 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - java8-tests_2.10 + java8-tests_2.11 pom Spark Project Java8 Tests POM diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 601080c2e6fbd..d1c38c7ca5d69 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl-assembly_2.10 + spark-streaming-kinesis-asl-assembly_2.11 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 20e2c5e0ffbee..935155eb5d362 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl_2.10 + spark-streaming-kinesis-asl_2.11 jar Spark Kinesis Integration diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index b046a10a04d5b..bfb92791de3d8 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-ganglia-lgpl_2.10 + spark-ganglia-lgpl_2.11 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 388a0ef06a2b0..1813f383cdcba 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-graphx_2.10 + spark-graphx_2.11 graphx diff --git a/launcher/pom.xml b/launcher/pom.xml index 135866cea2e74..ef731948826ef 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-launcher_2.10 + spark-launcher_2.11 jar Spark Project Launcher http://spark.apache.org/ diff --git a/mllib/pom.xml b/mllib/pom.xml index 42af2b8b3e411..816f3f6830382 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-mllib_2.10 + spark-mllib_2.11 mllib diff --git 
a/network/common/pom.xml b/network/common/pom.xml index eda2b7307088f..bd507c2cb6c4b 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-common_2.10 + spark-network-common_2.11 jar Spark Project Networking http://spark.apache.org/ diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index f9aa7e2dd1f43..810ec10ca05b3 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_2.11 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a19cbb04b18c6..a28785b16e1e6 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-yarn_2.10 + spark-network-yarn_2.11 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/pom.xml b/pom.xml index fb7750602c425..d0387aca66d0d 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 14 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT pom Spark Project Parent POM @@ -165,7 +165,7 @@ 3.2.2 2.10.5 - 2.10 + 2.11 ${scala.version} org.scala-lang 1.9.13 @@ -2456,7 +2456,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 2.10.5 @@ -2488,7 +2488,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 2.11.7 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 41856443af49b..4adf64a5a0d86 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -95,7 +95,7 @@ object MimaBuild { // because spark-streaming-mqtt(1.6.0) depends on it. // Remove the setting on updating previousSparkVersion. 
val previousSparkVersion = "1.6.0" - val fullId = "spark-" + projectRef.project + "_2.10" + val fullId = "spark-" + projectRef.project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a3ae4d2b730ff..3748e07f88aad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -220,6 +220,12 @@ object MimaExcludes { // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") + ) ++ Seq( + // SPARK-6363 Make Scala 2.11 the default Scala version + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") ) case v if v.startsWith("1.6") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4224a65a822b8..550b5bad8a46a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -119,11 +119,11 @@ object SparkBuild extends PomBuild { v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - if (System.getProperty("scala-2.11") == "") { - // To activate scala-2.11 profile, replace empty property value to non-empty value + if (System.getProperty("scala-2.10") == "") { + // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. 
// see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.11", "true") + System.setProperty("scala-2.10", "true") } profiles } @@ -382,7 +382,7 @@ object OldDeps { lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" + val fullId = id + "_2.11" Some("org.apache.spark" % fullId % "1.2.0") } @@ -390,7 +390,7 @@ object OldDeps { name := "old-deps", scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming-flume", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-graphx", "spark-core").map(versionArtifact(_).get intransitive()) ) @@ -704,7 +704,7 @@ object Java8TestSettings { lazy val settings = Seq( javacJVMVersion := "1.8", // Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher: - scalacJVMVersion := (if (System.getProperty("scala-2.11") == "true") "1.8" else "1.7") + scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8") ) } diff --git a/repl/pom.xml b/repl/pom.xml index efc3dd452e329..0f396c9b809bd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-repl_2.10 + spark-repl_2.11 jar Spark Project REPL http://spark.apache.org/ @@ -159,7 +159,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 @@ -173,7 +173,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 scala-2.11/src/main/scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index bb3081d12938e..07ba28bb07545 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -33,7 +33,8 @@ object Main extends Logging { var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ - var interp = new SparkILoop // this is a public var because tests reset it. + // this is a public var because tests reset it. 
+ var interp: SparkILoop = _ private var hasErrors = false @@ -43,6 +44,12 @@ object Main extends Logging { } def main(args: Array[String]) { + doMain(args, new SparkILoop) + } + + // Visible for testing + private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { + interp = _interp val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 63f3688c9e612..b9ed79da421a6 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -50,12 +50,7 @@ class ReplSuite extends SparkFunSuite { System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) System.setProperty("spark.master", master) - val interp = { - new SparkILoop(in, new PrintWriter(out)) - } - org.apache.spark.repl.Main.interp = interp - Main.main(Array("-classpath", classpath)) // call main - org.apache.spark.repl.Main.interp = null + Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) if (oldExecutorClasspath != null) { System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 76ca3f3bb1bfa..c2ad9b99f3ac9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-catalyst_2.10 + spark-catalyst_2.11 jar Spark Project Catalyst http://spark.apache.org/ @@ -127,13 +127,4 @@ - - - - scala-2.10 - - !scala-2.11 - - - diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 4bb55f6b7f739..89e01fc01596e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql_2.10 + spark-sql_2.11 jar Spark Project SQL http://spark.apache.org/ @@ -44,7 +44,7 @@ org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 ${project.version} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 435e565f63458..c8d17bd468582 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive-thriftserver_2.10 + spark-hive-thriftserver_2.11 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index cd0c2aeb93a9f..14cf9acf09d5b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive_2.10 + spark-hive_2.11 jar Spark Project Hive http://spark.apache.org/ diff --git a/streaming/pom.xml b/streaming/pom.xml index 39cbd0d00f951..7d409c5d3b076 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-streaming_2.10 + spark-streaming_2.11 streaming diff --git a/tags/pom.xml b/tags/pom.xml index 9e4610dae7a65..3e8e6f6182875 100644 --- a/tags/pom.xml +++ b/tags/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - 
spark-test-tags_2.10 + spark-test-tags_2.11 jar Spark Project Test Tags http://spark.apache.org/ diff --git a/tools/pom.xml b/tools/pom.xml index 30cbb6a5a59c7..b3a5ae2771241 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-tools_2.10 + spark-tools_2.11 tools diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 21fef3415adce..75fea556eeae1 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-unsafe_2.10 + spark-unsafe_2.11 jar Spark Project Unsafe http://spark.apache.org/ diff --git a/yarn/pom.xml b/yarn/pom.xml index a8c122fd40a1f..328bb6678db99 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-yarn_2.10 + spark-yarn_2.11 jar Spark Project YARN From de283719980ae78b740e507e4d70c7ebbf6c5f74 Mon Sep 17 00:00:00 2001 From: wangyang Date: Sat, 30 Jan 2016 15:20:57 -0800 Subject: [PATCH 028/107] [SPARK-13100][SQL] improving the performance of stringToDate method in DateTimeUtils.scala In jdk1.7 TimeZone.getTimeZone() is synchronized, so use an instance variable to hold an GMT TimeZone object instead of instantiate it every time. Author: wangyang Closes #10994 from wangyang1992/datetimeUtil. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f18c052b68e37..a159bc6a61415 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -55,6 +55,7 @@ object DateTimeUtils { // this is year -17999, calculation: 50 * daysIn400Year final val YearZero = -17999 final val toYearZero = to2001 + 7304850 + final val TimeZoneGMT = TimeZone.getTimeZone("GMT") @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -407,7 +408,7 @@ object DateTimeUtils { segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + val c = Calendar.getInstance(TimeZoneGMT) c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) From a1303de0a0e9d0c80327977abf52a79e2aa95e1f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 30 Jan 2016 23:02:49 -0800 Subject: [PATCH 029/107] [SPARK-13070][SQL] Better error message when Parquet schema merging fails Make sure we throw better error messages when Parquet schema merging fails. Author: Cheng Lian Author: Liang-Chi Hsieh Closes #10979 from viirya/schema-merging-failure-message. 
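For context, a minimal way to reproduce the failure whose message this patch improves — a sketch adapted from the ParquetSchemaSuite test added below, assuming `sqlContext` is available and `path` points at an empty temporary directory:

```scala
// Write two Parquet parts whose "id" columns have incompatible types:
// part p=1 stores id as bigint, part p=2 stores id as int.
sqlContext.range(3).write.parquet(s"$path/p=1")
sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")

// With mergeSchema enabled, reading the table forces a schema merge and throws a
// SparkException; after this patch the message also names the offending file
// ("Failed merging schema of file ...") instead of only the low-level type conflict.
sqlContext.read.option("mergeSchema", "true").parquet(path).schema
```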
--- .../apache/spark/sql/types/StructType.scala | 6 ++-- .../datasources/parquet/ParquetRelation.scala | 33 ++++++++++++++++--- .../parquet/ParquetFilterSuite.scala | 15 +++++++++ .../parquet/ParquetSchemaSuite.scala | 30 +++++++++++++++++ 4 files changed, 77 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index da0c92864e9bd..c9e7e7fe633b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -424,13 +424,13 @@ object StructType extends AbstractDataType { if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { DecimalType(leftPrecision, leftScale) } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") } else if (leftPrecision != rightPrecision) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision") } else { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"scala $leftScale and $rightScale") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f87590095d344..1e686d41f41db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -800,12 +800,37 @@ private[sql] object ParquetRelation extends Logging { assumeInt96IsTimestamp = assumeInt96IsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) - footers.map { footer => - ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceLeftOption(_ merge _).iterator + if (footers.isEmpty) { + Iterator.empty + } else { + var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter) + footers.tail.foreach { footer => + val schema = ParquetRelation.readSchemaFromFooter(footer, converter) + try { + mergedSchema = mergedSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } }.collect() - partiallyMergedSchemas.reduceLeftOption(_ merge _) + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1796b3af0e37a..3ded32c450541 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -421,6 +421,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // We will remove the temporary metadata when writing Parquet file. val forPathSix = sqlContext.read.parquet(pathSix).schema assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. + assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 60fa81b1ab819..d860651d421f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -449,6 +450,35 @@ class ParquetSchemaSuite extends ParquetSchemaTest { }.getMessage.contains("detected conflicting schemas")) } + test("schema merging failure error message") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema of file")) + } + + // test for second merging (after read Parquet schema in parallel done) + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema:")) + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From 0e6d92d042b0a2920d8df5959d5913ba0166a678 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 30 Jan 2016 23:05:29 -0800 Subject: [PATCH 030/107] [SPARK-12689][SQL] Migrate DDL parsing to the newly absorbed parser JIRA: 
https://issues.apache.org/jira/browse/SPARK-12689 DDLParser processes three commands: createTable, describeTable and refreshTable. This patch migrates the three commands to newly absorbed parser. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10723 from viirya/migrate-ddl-describe. --- project/MimaExcludes.scala | 5 + .../sql/catalyst/parser/ExpressionParser.g | 14 ++ .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 +- .../sql/catalyst/parser/SparkSqlParser.g | 80 +++++++- .../spark/sql/catalyst/CatalystQl.scala | 23 ++- .../org/apache/spark/sql/SQLContext.scala | 5 +- .../apache/spark/sql/execution/SparkQl.scala | 101 ++++++++- .../sql/execution/datasources/DDLParser.scala | 193 ------------------ .../spark/sql/execution/datasources/ddl.scala | 5 - .../sources/CreateTableAsSelectSuite.scala | 7 +- 10 files changed, 208 insertions(+), 229 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3748e07f88aad..8b1a7303fc5b2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -200,6 +200,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12689 Migrate DDL parsing to the newly absorbed parser + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") ) ++ Seq( // SPARK-7799 Add "streaming-akka" project ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 0555a6ba83cbb..c162c1a0c5789 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -493,6 +493,16 @@ descFuncNames | functionIdentifier ; +//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. +looseIdentifier + : + Identifier + | looseNonReserved -> Identifier[$looseNonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? 
sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + identifier : Identifier @@ -516,6 +526,10 @@ principalIdentifier | QuotedIdentifier ; +looseNonReserved + : nonReserved | KW_FROM | KW_TO + ; + //The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved //Non reserved keywords are basically the keywords that can be used as identifiers. //All the KW_* are automatically not only keywords, but also reserved keywords. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 4374cd7ef7200..e930caa291d4f 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -324,6 +324,8 @@ KW_ISOLATION: 'ISOLATION'; KW_LEVEL: 'LEVEL'; KW_SNAPSHOT: 'SNAPSHOT'; KW_AUTOCOMMIT: 'AUTOCOMMIT'; +KW_REFRESH: 'REFRESH'; +KW_OPTIONS: 'OPTIONS'; KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; @@ -470,7 +472,7 @@ Identifier fragment QuotedIdentifier : - '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); } ; WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 35bef00351d72..6591f6b0f56ce 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -142,6 +142,7 @@ TOK_UNIONTYPE; TOK_COLTYPELIST; TOK_CREATEDATABASE; TOK_CREATETABLE; +TOK_CREATETABLEUSING; TOK_TRUNCATETABLE; TOK_CREATEINDEX; TOK_CREATEINDEX_INDEXTBLNAME; @@ -371,6 +372,10 @@ TOK_TXN_READ_WRITE; TOK_COMMIT; TOK_ROLLBACK; TOK_SET_AUTOCOMMIT; +TOK_REFRESHTABLE; +TOK_TABLEPROVIDER; +TOK_TABLEOPTIONS; +TOK_TABLEOPTION; TOK_CACHETABLE; TOK_UNCACHETABLE; TOK_CLEARCACHE; @@ -660,6 +665,12 @@ import java.util.HashMap; } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { + if (input.length() > 0) { + if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') { + // When column name is backquoted, we don't care about excluded chars. + return false; + } + } for(char c : excludedCharForColumnName) { if(input.indexOf(c)>-1) { return true; @@ -781,6 +792,7 @@ ddlStatement | truncateTableStatement | alterStatement | descStatement + | refreshStatement | showStatement | metastoreCheck | createViewStatement @@ -907,12 +919,31 @@ createTableStatement @init { pushMsg("create table statement", state); } @after { popMsg(state); } : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( like=KW_LIKE likeName=tableName + ( + like=KW_LIKE likeName=tableName tableRowFormat? tableFileFormat? tableLocation? tablePropertiesPrefixed? + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + ) + | + tableProvider + tableOpts? + (KW_AS selectStatementWithCTE)? + -> ^(TOK_CREATETABLEUSING $name $temp? 
ifNotExists? + tableProvider + tableOpts? + selectStatementWithCTE? + ) | (LPAREN columnNameTypeList RPAREN)? + (p=tableProvider?) + tableOpts? tableComment? tablePartition? tableBuckets? @@ -922,8 +953,15 @@ createTableStatement tableLocation? tablePropertiesPrefixed? (KW_AS selectStatementWithCTE)? - ) - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + -> {p != null}? + ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? + columnNameTypeList? + $p + tableOpts? + selectStatementWithCTE? + ) + -> + ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? ^(TOK_LIKETABLE $likeName?) columnNameTypeList? tableComment? @@ -935,7 +973,8 @@ createTableStatement tableLocation? tablePropertiesPrefixed? selectStatementWithCTE? - ) + ) + ) ; truncateTableStatement @@ -1379,6 +1418,13 @@ tabPartColTypeExpr : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) ; +refreshStatement +@init { pushMsg("refresh statement", state); } +@after { popMsg(state); } + : + KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName) + ; + descStatement @init { pushMsg("describe statement", state); } @after { popMsg(state); } @@ -1774,6 +1820,30 @@ showStmtIdentifier | StringLiteral ; +tableProvider +@init { pushMsg("table's provider", state); } +@after { popMsg(state); } + : + KW_USING Identifier (DOT Identifier)* + -> ^(TOK_TABLEPROVIDER Identifier+) + ; + +optionKeyValue +@init { pushMsg("table's option specification", state); } +@after { popMsg(state); } + : + (looseIdentifier (DOT looseIdentifier)*) StringLiteral + -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral) + ; + +tableOpts +@init { pushMsg("table's options", state); } +@after { popMsg(state); } + : + KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN + -> ^(TOK_TABLEOPTIONS optionKeyValue+) + ; + tableComment @init { pushMsg("table's comment", state); } @after { popMsg(state); } @@ -2132,7 +2202,7 @@ structType mapType @init { pushMsg("map type", state); } @after { popMsg(state); } - : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN -> ^(TOK_MAP $left $right) ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 536c292ab7f34..7ce2407913ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -140,6 +140,7 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends case Token("TOK_BOOLEAN", Nil) => BooleanType case Token("TOK_STRING", Nil) => StringType case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType case Token("TOK_DATE", Nil) => DateType @@ -156,9 +157,10 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends protected def nodeToStructField(node: ASTNode): StructField = node match { case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) + case 
Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) => + val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) case _ => noParseRule("StructField", node) } @@ -222,15 +224,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Nil => ShowFunctions(None, None) case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(name))) + ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name))) + ShowFunctions(Some(unquoteString(cleanIdentifier(db))), + Some(unquoteString(cleanIdentifier(name)))) case _ => noParseRule("SHOW FUNCTIONS", node) } case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(functionName, isExtended.nonEmpty) + DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = @@ -611,7 +614,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C noParseRule("Select", node) } - protected val escapedIdentifier = "`([^`]+)`".r + protected val escapedIdentifier = "`(.+)`".r protected val doubleQuotedString = "\"([^\"]+)\"".r protected val singleQuotedString = "'([^']+)'".r @@ -655,7 +658,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nodeToExpr(qualifier) match { case UnresolvedAttribute(nameParts) => UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) } /* Stars (*) */ @@ -663,7 +666,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. 
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(_.text))) + UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) /* Aggregate Functions */ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => @@ -971,7 +974,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - val alias = getClause("TOK_TABALIAS", clauses).children.head.text + val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) val generator = clauses.head match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index be28df3a51557..ef993c3edae37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -206,10 +206,7 @@ class SQLContext private[sql]( @transient protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) - @transient - protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) - - protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index a5bd8ee42dec9..4174e27e9c8b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -16,11 +16,14 @@ */ package org.apache.spark.sql.execution +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.types.StructType private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. 
*/ @@ -55,6 +58,86 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_REFRESHTABLE", nameParts :: Nil) => + val tableIdent = extractTableIdent(nameParts) + RefreshTable(tableIdent) + + case Token("TOK_CREATETABLEUSING", createTableArgs) => + val Seq( + temp, + allowExisting, + Some(tabName), + tableCols, + Some(Token("TOK_TABLEPROVIDER", providerNameParts)), + tableOpts, + tableAs) = getClauses(Seq( + "TEMPORARY", + "TOK_IFNOTEXISTS", + "TOK_TABNAME", "TOK_TABCOLLIST", + "TOK_TABLEPROVIDER", + "TOK_TABLEOPTIONS", + "TOK_QUERY"), createTableArgs) + + val tableIdent: TableIdentifier = extractTableIdent(tabName) + + val columns = tableCols.map { + case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField)) + } + + val provider = providerNameParts.map { + case Token(name, Nil) => name + }.mkString(".") + + val options: Map[String, String] = tableOpts.toSeq.flatMap { + case Token("TOK_TABLEOPTIONS", options) => + options.map { + case Token("TOK_TABLEOPTION", keysAndValue) => + val key = keysAndValue.init.map(_.text).mkString(".") + val value = unquoteString(keysAndValue.last.text) + (key, value) + } + }.toMap + + val asClause = tableAs.map(nodeToPlan(_)) + + if (temp.isDefined && allowExisting.isDefined) { + throw new AnalysisException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + if (asClause.isDefined) { + if (columns.isDefined) { + throw new AnalysisException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + CreateTableUsingAsSelect(tableIdent, + provider, + temp.isDefined, + Array.empty[String], + bucketSpec = None, + mode, + options, + asClause.get) + } else { + CreateTableUsing( + tableIdent, + columns, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => SetDatabaseCommand(cleanIdentifier(database)) @@ -68,26 +151,30 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly nodeToDescribeFallback(node) } else { tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) => nameParts match { - case Token(".", dbName :: tableName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this // issue. - val tableIdent = extractTableIdent(nameParts) + val tableIdent = TableIdentifier( + cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) datasources.DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => // It is describing a column with the format like "describe db.table column". nodeToDescribeFallback(node) - case tableName => + case tableName :: Nil => // It is describing a table with the format like "describe table". 
datasources.DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.text), None), + UnresolvedRelation(TableIdentifier(cleanIdentifier(tableName.text)), None), isExtended = extended.isDefined) + case _ => + nodeToDescribeFallback(node) } // All other cases. - case _ => nodeToDescribeFallback(node) + case _ => + nodeToDescribeFallback(node) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala deleted file mode 100644 index f4766b037027d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ /dev/null @@ -1,193 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import scala.language.implicitConversions -import scala.util.matching.Regex - -import org.apache.spark.Logging -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ - -/** - * A parser for foreign DDL commands. 
- */ -class DDLParser(fallback: => ParserInterface) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = { - fallback.parseTableIdentifier(sql) - } - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parsePlan(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => fallback.parsePlan(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = { - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = fallback.parsePlan(query.get) - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - bucketSpec = None, - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableIdent, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - } - - // This is the same as tableIdentifier in SqlParser. 
- protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { - case e ~ tableIdent => - DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> tableIdentifier ^^ { - case tableIndet => - RefreshTable(tableIndet) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c3603936dfd2e..1554209be9891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -169,8 +169,3 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } - -/** - * The exception thrown from the DDL parser. 
- */ -class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6fc9febe49707..cb88a1c83c999 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,7 +22,6 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -105,7 +104,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -156,7 +155,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -173,7 +172,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("a CTAS statement with column definitions is not allowed") { - intercept[DDLException]{ + intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) From 5a8b978fabb60aa178274f86432c63680c8b351a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 31 Jan 2016 13:56:13 -0800 Subject: [PATCH 031/107] [SPARK-13049] Add First/last with ignore nulls to functions.scala This PR adds the ability to specify the ```ignoreNulls``` option to the functions dsl, e.g: ```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other"))``` This PR is some where between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6. cc yhuai Author: Herman van Hovell Closes #10957 from hvanhovell/SPARK-13049. --- python/pyspark/sql/functions.py | 26 +++- python/pyspark/sql/tests.py | 10 ++ .../org/apache/spark/sql/functions.scala | 118 ++++++++++++++---- .../spark/sql/DataFrameWindowSuite.scala | 32 +++++ 4 files changed, 157 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719eca8f5559e..0d5708526701e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ def _(): 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -278,6 +276,18 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. 
+ + The function by default returns the first values it sees. It will return the first non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -310,6 +320,18 @@ def isnull(col): return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + + The function by default returns the last values it sees. It will return the last non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 410efbafe0792..e30aa0a796924 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -641,6 +641,16 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a27466176a20..b970eee4e31a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -349,19 +349,51 @@ object functions extends LegacyFunctions { } /** - * Aggregate function: returns the first value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def first(e: Column): Column = withAggregateFunction { new First(e.expr) } - - /** - * Aggregate function: returns the first value of a column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new First(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. 
+ * + * @group agg_funcs + * @since 2.0.0 + */ + def first(columnName: String, ignoreNulls: Boolean): Column = { + first(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def first(e: Column): Column = first(e, ignoreNulls = false) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(columnName: String): Column = first(Column(columnName)) /** @@ -381,20 +413,52 @@ object functions extends LegacyFunctions { def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** - * Aggregate function: returns the last value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } - - /** - * Aggregate function: returns the last value of the column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(columnName: String): Column = last(Column(columnName)) + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new Last(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(columnName: String, ignoreNulls: Boolean): Column = { + last(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(e: Column): Column = last(e, ignoreNulls = false) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** * Aggregate function: returns the maximum value of the expression in a group. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 09a56f6f3ae28..d38842c3c0cf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -312,4 +312,36 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 3, null, null), Row("b", 2, null, null))) } + + test("last/first with ignoreNulls") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, nullStr), + ("b", 1, nullStr), + ("b", 2, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order") + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, null, null, null, null, null, null), + Row("a", 1, null, null, "x", "x", "x", "x"), + Row("a", 2, null, null, "x", "y", "y", "y"), + Row("a", 3, null, null, "x", "z", "z", "z"), + Row("a", 4, null, null, "x", null, null, "z"), + Row("b", 1, null, null, null, null, null, null), + Row("b", 2, null, null, null, null, null, null))) + } } From c1da4d421ab78772ffa52ad46e5bdfb4e5268f47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 Jan 2016 22:43:03 -0800 Subject: [PATCH 032/107] [SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression The current implementation is sub-optimal: * If an expression is always nullable, e.g. `Unhex`, we can still remove null check for children if they are not nullable. * If an expression has some non-nullable children, we can still remove null check for these children and keep null check for others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan Closes #10987 from cloud-fan/null-check. --- .../sql/catalyst/expressions/Expression.scala | 104 ++++++++++-------- .../expressions/codegen/CodeGenerator.scala | 32 +++++- .../spark/sql/catalyst/expressions/misc.scala | 16 +-- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db17ba7c84ffc..353fb92581d3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. 
* * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. @@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + if (nullable) { - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${eval.isNull}) { - ${f(eval.value)} - } + $nullSafeEval """ } else { ev.isNull = "false" - eval.code + s""" + s""" + ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${f(eval.value)} + $resultCode """ } } @@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } + $nullSafeEval """ - } else { ev.isNull = "false" s""" - ${eval1.code} - ${eval2.code} + ${leftGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ @@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( ctx: CodegenContext, @@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. 
* - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + } + s""" - ${evals(0).code} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode - } - } - } + $nullSafeEval """ } else { ev.isNull = "false" s""" - ${evals(0).code} - ${evals(1).code} - ${evals(2).code} + ${leftGen.code} + ${midGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 21f9198073d74..a30aba16170a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,17 +402,37 @@ class CodegenContext { } /** - * Generates code for greater of two expressions. - * - * @param dataType data type of the expressions - * @param c1 name of the variable of expression 1's output - * @param c2 name of the variable of expression 2's output - */ + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. + * + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. + */ + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execute + } + """ + } else { + "\n" + execute + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8480c3f9a12f6..36e1fa1176d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression """ } - private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { - if (nullable) { - s""" - if (!$isNull) { - $execution - } - """ - } else { - "\n" + execution - } - } - private def nullSafeElementHash( input: String, index: String, @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - generateNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} From 6075573a93176ee8c071888e4525043d9e73b061 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 1 Feb 2016 11:02:17 -0800 Subject: [PATCH 033/107] [SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream Add a local property to indicate if checkpointing all RDDs that are marked with the checkpoint flag, and enable it in Streaming Author: Shixiong Zhu Closes #10934 from zsxwing/recursive-checkpoint. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++ .../org/apache/spark/CheckpointSuite.scala | 21 ++++++ .../streaming/scheduler/JobGenerator.scala | 5 ++ .../streaming/scheduler/JobScheduler.scala | 7 +- .../spark/streaming/CheckpointSuite.scala | 69 +++++++++++++++++++ 5 files changed, 119 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index be47172581b7f..e8157cf4ebe7d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1542,6 +1542,15 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default, + // we stop as soon as we find the first such RDD, an optimization that allows us to write + // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both + // an RDD and its parent in every batch, in which case the parent may never be checkpointed + // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). 
+ private val checkpointAllMarkedAncestors = + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) + .map(_.toBoolean).getOrElse(false) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -1585,6 +1594,13 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { + if (checkpointAllMarkedAncestors) { + // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint + // them in parallel. + // Checkpoint parents first because our lineage will be truncated after we + // checkpoint ourselves + dependencies.foreach(_.rdd.doCheckpoint()) + } checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) @@ -1704,6 +1720,9 @@ abstract class RDD[T: ClassTag]( */ object RDD { + private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS = + "spark.checkpoint.checkpointAllMarkedAncestors" + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 390764ba242fd..ce35856dce3f7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } + + runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean => + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true) + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false) + } + + private def testCheckpointAllMarkedAncestors( + reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString) + try { + val rdd1 = sc.parallelize(1 to 10) + checkpoint(rdd1, reliableCheckpoint) + val rdd2 = rdd1.map(_ + 1) + checkpoint(rdd2, reliableCheckpoint) + rdd2.count() + assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors) + assert(rdd2.isCheckpointed === true) + } finally { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null) + } + } } /** RDD partition that has large serialized size. 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index a5a01e77639c4..a3ad5eaa40edc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Example: BlockRDDs are created in this thread, and it needs to access BlockManager // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. SparkEnv.set(ssc.env) + + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") Try { jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch graph.generateJobs(time) // generate jobs using allocated block diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 9535c8e5b768a..3fed3d88354c7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -23,10 +23,10 @@ import scala.collection.JavaConverters._ import scala.util.Failure import org.apache.spark.Logging -import org.apache.spark.rdd.PairRDDFunctions +import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { s"""Streaming job from $batchLinkText""") ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") // We need to assign `eventLoop` to a temp variable. 
Otherwise, because // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4a6b91fbc745e..786703eb9a84e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester checkpointWriter.stop() } + test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") { + // In this test, there are two updateStateByKey operators. The RDD DAG is as follows: + // + // batch 1 batch 2 batch 3 ... + // + // 1) input rdd input rdd input rdd + // | | | + // v v v + // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 3) map rdd --- map rdd --- map rdd ... + // | | | + // v v v + // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 5) map rdd --- map rdd --- map rdd ... + // + // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to + // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second + // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first + // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3) + // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken + // and the RDD chain will grow infinitely and cause StackOverflow. + // + // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing + // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break + // connections between layer 2 and layer 3) + ssc = new StreamingContext(master, framework, batchDuration) + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + @volatile var shouldCheckpointAllMarkedRDDs = false + @volatile var rddsCheckpointed = false + inputDStream.map(i => (i, i)) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .foreachRDD { rdd => + /** + * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors. + */ + def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = { + val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList + if (rdd.checkpointData.isDefined) { + rdd :: markedRDDs + } else { + markedRDDs + } + } + + shouldCheckpointAllMarkedRDDs = + Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)). + map(_.toBoolean).getOrElse(false) + + val stateRDDs = findAllMarkedRDDs(rdd) + rdd.count() + // Check the two state RDDs are both checkpointed + rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed) + } + ssc.start() + batchCounter.waitUntilBatchesCompleted(1, 10000) + assert(shouldCheckpointAllMarkedRDDs === true) + assert(rddsCheckpointed === true) + } + /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. 
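For a quick, non-streaming illustration of the new flag, the core `CheckpointSuite` test above boils down to roughly the following sketch. It assumes a local `SparkContext` and a writable checkpoint directory, and spells the property key out as a string because `RDD.CHECKPOINT_ALL_MARKED_ANCESTORS` is `private[spark]`:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("spark-6847-sketch"))
sc.setCheckpointDir("/tmp/spark-6847-checkpoints")   // assumption: any writable directory

// With the property unset (default false) only rdd2 gets checkpointed and rdd1's
// lineage survives; with it set to true, ancestors marked for checkpointing are
// checkpointed as well, truncating the whole chain.
sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")

val rdd1 = sc.parallelize(1 to 10)
rdd1.checkpoint()                    // marked, but normally skipped once rdd2 is checkpointed
val rdd2 = rdd1.map(_ + 1)
rdd2.checkpoint()

rdd2.count()                         // triggers doCheckpoint(); parents are handled first
assert(rdd1.isCheckpointed && rdd2.isCheckpointed)

sc.stop()
```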
From 33c8a490f7f64320c53530a57bd8d34916e3607c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:22:02 -0800 Subject: [PATCH 034/107] [SPARK-12989][SQL] Delaying Alias Cleanup after ExtractWindowExpressions JIRA: https://issues.apache.org/jira/browse/SPARK-12989 In the rule `ExtractWindowExpressions`, we simply replace alias by the corresponding attribute. However, this will cause an issue exposed by the following case: ```scala val data = Seq(("a", "b", "c", 3), ("c", "b", "a", 3)).toDF("A", "B", "C", "num") .withColumn("Data", struct("A", "B", "C")) .drop("A") .drop("B") .drop("C") val winSpec = Window.partitionBy("Data.A", "Data.B").orderBy($"num".desc) data.select($"*", max("num").over(winSpec) as "max").explain(true) ``` In this case, both `Data.A` and `Data.B` are `alias` in `WindowSpecDefinition`. If we replace these alias expression by their alias names, we are unable to know what they are since they will not be put in `missingExpr` too. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10963 from gatorsmile/seletStarAfterColDrop. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- .../org/apache/spark/sql/DataFrameWindowSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5fe700ee00673..ee60fca1ad4fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -883,12 +883,13 @@ class Analyzer( if (missingExpr.nonEmpty) { extractedExprBuffer += ne } - ne.toAttribute + // alias will be cleaned in the rule CleanupAliases + ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with - // an interal column name, e.g. "_w0"). + // an internal column name, e.g. "_w0"). 
val withName = Alias(e, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index d38842c3c0cf0..2bcbb1983f7ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -344,4 +344,14 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 1, null, null, null, null, null, null), Row("b", 2, null, null, null, null, null, null))) } + + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { + val src = Seq((0, 3, 5)).toDF("a", "b", "c") + .withColumn("Data", struct("a", "b")) + .drop("a") + .drop("b") + val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) + val df = src.select($"*", max("c").over(winSpec) as "max") + checkAnswer(df, Row(5, Row(0, 3), 5)) + } } From 8f26eb5ef6853a6666d7d9481b333de70bc501ed Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:57:13 -0800 Subject: [PATCH 035/107] [SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences JIRA: https://issues.apache.org/jira/browse/SPARK-12705 **Scope:** This PR is a general fix for sorting reference resolution when the child's `outputSet` does not have the order-by attributes (called, *missing attributes*): - UnaryNode support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, `RepartitionByExpression`. - We will not try to resolve the missing references inside a subquery, unless the outputSet of this subquery contains it. **General Reference Resolution Rules:** - Jump over the nodes with the following types: `Distinct`, `Filter`, `RepartitionByExpression`. Do not need to add missing attributes. The reason is their `outputSet` is decided by their `inputSet`, which is the `outputSet` of their children. - Group-by expressions in `Aggregate`: missing order-by attributes are not allowed to be added into group-by expressions since it will change the query result. Thus, in RDBMS, it is not allowed. - Aggregate expressions in `Aggregate`: if the group-by expressions in `Aggregate` contains the missing attributes but aggregate expressions do not have it, just add them into the aggregate expressions. This can resolve the analysisExceptions thrown by the three TCPDS queries. - `Project` and `Window` are special. We just need to add the missing attributes to their `projectList`. **Implementation:** 1. Traverse the whole tree in a pre-order manner to find all the resolvable missing order-by attributes. 2. Traverse the whole tree in a post-order manner to add the found missing order-by attributes to the node if their `inputSet` contains the attributes. 3. If the origins of the missing order-by attributes are different nodes, each pass only resolves the missing attributes that are from the same node. **Risk:** Low. This rule will be trigger iff ```!s.resolved && child.resolved``` is true. Thus, very few cases are affected. Author: gatorsmile Closes #10678 from gatorsmile/sortWindows. 
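For illustration, this is the kind of query that previously failed analysis and is now resolved (a hand-written sketch in the spark-shell style, not taken from the new test cases; it assumes `sqlContext` implicits are in scope):

```scala
// "key" is dropped by the inner select() and the Sort's immediate child is a Filter,
// so the old rule (which only looked at a Project directly under the Sort) gave up.
// The new rule walks past the Filter, re-adds "key" to the Project, and projects it
// away again above the Sort.
val df = Seq(("a", 1), ("b", 7), ("c", 3)).toDF("key", "value")

df.filter($"value" > 0)
  .select($"value")
  .filter($"value" < 6)
  .orderBy($"key".asc)
  .show()
```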
--- .../sql/catalyst/analysis/Analyzer.scala | 101 ++++++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 83 ++++++++++++++ .../sql/catalyst/analysis/TestRelations.scala | 6 ++ .../apache/spark/sql/DataFrameJoinSuite.scala | 16 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 84 ++++++++++++++- 6 files changed, 274 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ee60fca1ad4fe..a983dc1cdfebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -452,7 +453,7 @@ class Analyzer( i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) @@ -533,38 +534,96 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) + // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, child: Aggregate) => sa - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, child))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") + case s @ Sort(_, _, child) if !s.resolved && child.resolved => + val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) + + if (missingResolvableAttrs.isEmpty) { + val unresolvableAttrs = s.order.filterNot(_.resolved) + logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. + } else { + // Add the missing attributes into projectList of Project/Window or + // aggregateExpressions of Aggregate, if they are in the inputSet + // but not in the outputSet of the plan. 
+ val newChild = child transformUp { + case p: Project => + p.copy(projectList = p.projectList ++ + missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains)) + case w: Window => + w.copy(projectList = w.projectList ++ + missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains)) + case a: Aggregate => + val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains) + val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains) + val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case o => o + } + + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrdering, s.global, newChild)) } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Traverse the tree until resolving the sorting attributes + * Return all the resolvable missing sorting attributes + */ + @tailrec + private def collectResolvableMissingAttrs( + ordering: Seq[SortOrder], + plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + plan match { + // Only Windows and Project have projectList-like attribute. + case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. + if (missingAttrs.nonEmpty) { + (newOrdering, missingAttrs) + } else { + collectResolvableMissingAttrs(ordering, un.child) + } + case a: Aggregate => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) + // For Aggregate, all the order by columns must be specified in group by clauses + if (missingAttrs.nonEmpty && + missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { + (newOrdering, missingAttrs) + } else { + // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes + (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + // Jump over the following UnaryNode types + // The output of these types is the same as their child's output + case _: Distinct | + _: Filter | + _: RepartitionByExpression => + collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child) + // If hitting the other unsupported operators, we are unable to resolve it. + case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + } + + /** + * Try to resolve the sort ordering and returns it with a list of attributes that are missing + * from the plan but are present in the child. */ - def resolveAndFindMissing( + private def resolveAndFindMissing( ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) + plan: LogicalPlan, + child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + val newOrdering = resolveSortOrders(ordering, child, throws = false) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. 
- val missingInProject = requiredAttributes -- child.output + val missingInProject = requiredAttributes -- plan.outputSet // It is important to return the new SortOrders here, instead of waiting for the standard // resolving process as adding attributes to the project below can actually introduce // ambiguity that was not present before. @@ -719,7 +778,7 @@ class Analyzer( } } - protected def containsAggregate(condition: Expression): Boolean = { + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1938bce02a177..ebf885a8fe484 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } + test("resolve sort references - filter/limit") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + // Case 1: one missing attribute is in the leaf node and another is in the unary node + val plan1 = testRelation2 + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a, b).select(a) + checkAnalysis(plan1, expected1) + + // Case 2: all the missing attributes are in the leaf node + val plan2 = testRelation2 + .where('a > "str").select('a) + .where('a > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val h = testRelation3.output(3) + + // Case: join itself can resolve all the missing attributes + val plan = testRelation2.join(testRelation3) + .where('a > "str").select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected = testRelation2.join(testRelation3) + .where(a > "str").select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan, expected) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias_a3 = count(a).as("a3") + val alias_b = b.as("aggOrder") + + // Case 1: when the child of Sort is not Aggregate, + // the sort reference is handled by the rule ResolveSortReferences + val plan1 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected1 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, b) + .select(a, c, alias_a3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan1, expected1) + + // Case 2: when the child of Sort is Aggregate, + // the sort reference is handled by the rule ResolveAggregateFunctions + val plan2 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .orderBy('b.asc) + + val expected2 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, alias_b) + .orderBy(alias_b.toAttribute.asc) + .select(a, c, 
alias_a3.toAttribute) + + checkAnalysis(plan2, expected2) + } + test("resolve relations") { assertAnalysisError( UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index bc07b609a3413..3741a6ba95a86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c17be8ace9287..a5e5f156423cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) + } + test("join - join using multiple columns and specifying join type") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4ff99bdf2937d..c02133ffc8540 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ada2e325bda6..6048b8f5a3998 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") From da9146c91a33577ff81378ca7e7c38a4b1917876 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 1 Feb 2016 12:02:06 -0800 Subject: [PATCH 036/107] [DOCS] Fix the jar location of datanucleus in sql-programming-guid.md ISTM `lib` is better because `datanucleus` jars are located in `lib` for release builds. Author: Takeshi YAMAMURO Closes #10901 from maropu/DocFix. 
--- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fddc51379406b..550a40010e828 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1695,7 +1695,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running -the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. From 711ce048a285403241bbc9eaabffc1314162e89c Mon Sep 17 00:00:00 2001 From: Lewuathe Date: Mon, 1 Feb 2016 12:21:21 -0800 Subject: [PATCH 037/107] [ML][MINOR] Invalid MulticlassClassification reference in ml-guide In [ml-guide](https://spark.apache.org/docs/latest/ml-guide.html#example-model-selection-via-cross-validation), there is invalid reference to `MulticlassClassificationEvaluator` apidoc. https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator Author: Lewuathe Closes #10996 from Lewuathe/fix-typo-in-ml-guide. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5aafd53b584e7..f8279262e673f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -627,7 +627,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName` method in each of these evaluators. From 51b03b71ffc390e67b32936efba61e614a8b0d86 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 1 Feb 2016 12:45:02 -0800 Subject: [PATCH 038/107] [SPARK-12463][SPARK-12464][SPARK-12465][SPARK-10647][MESOS] Fix zookeeper dir with mesos conf and add docs. Fix zookeeper dir configuration used in cluster mode, and also add documentation around these settings. Author: Timothy Chen Closes #10057 from tnachen/fix_mesos_dir. 
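For reference, the corrected key names look roughly like this (illustrative values only; as the updated docs note, a dispatcher deployment would normally supply these as `-Dspark.deploy.*` entries in `SPARK_DAEMON_JAVA_OPTS` rather than building a `SparkConf` by hand):

```scala
import org.apache.spark.SparkConf

// Keys as documented and read after this change (see the diffs below).
val conf = new SparkConf()
  .set("spark.deploy.recoveryMode", "ZOOKEEPER")            // previously spark.mesos.deploy.recoveryMode
  .set("spark.deploy.zookeeper.url", "zk1:2181,zk2:2181")   // previously spark.mesos.deploy.zookeeper.url
  .set("spark.deploy.zookeeper.dir", "/spark_mesos")        // directory in ZooKeeper for recovery state
```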
--- .../deploy/mesos/MesosClusterDispatcher.scala | 6 ++--- .../mesos/MesosClusterPersistenceEngine.scala | 4 ++-- docs/configuration.md | 23 +++++++++++++++++++ docs/running-on-mesos.md | 5 +++- docs/spark-standalone.md | 23 ++++--------------- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 66e1e645007a7..9b31497adfb12 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -50,7 +50,7 @@ private[mesos] class MesosClusterDispatcher( extends Logging { private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { @@ -98,8 +98,8 @@ private[mesos] object MesosClusterDispatcher extends Logging { conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => - conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") - conf.set("spark.mesos.deploy.zookeeper.url", z) + conf.set("spark.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.deploy.zookeeper.url", z) } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index e0c547dce6d07..092d9e4182530 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -53,9 +53,9 @@ private[spark] trait MesosClusterPersistenceEngine { * all of them reuses the same connection pool. */ private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) - extends MesosClusterPersistenceEngineFactory(conf) { + extends MesosClusterPersistenceEngineFactory(conf) with Logging { - lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + lazy val zk = SparkCuratorUtil.newClient(conf) def createEngine(path: String): MesosClusterPersistenceEngine = { new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) diff --git a/docs/configuration.md b/docs/configuration.md index 74a8fb5d35a66..93b399d819ccd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1585,6 +1585,29 @@ Apart from these, the following properties are also available, and may be useful +#### Deploy + + + + + + + + + + + + + + + + + + +
        Property NameDefaultMeaniing
        spark.deploy.recoveryModeNONEThe recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches. + This is only applicable for cluster mode when running with Standalone or Mesos.
        spark.deploy.zookeeper.urlNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.
        spark.deploy.zookeeper.dirNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.
        + + #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ed720f1039f94..0ef1ccb36e117 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -153,7 +153,10 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. + +The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fe9ec3542b28..3de72bc016dd4 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -112,8 +112,8 @@ You can optionally configure the cluster further by setting environment variable SPARK_LOCAL_DIRS - Directory to use for "scratch" space in Spark, including map output files and RDDs that get - stored on disk. This should be on a fast, local disk in your system. It can also be a + Directory to use for "scratch" space in Spark, including map output files and RDDs that get + stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. @@ -341,23 +341,8 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** -In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration: - - - - - - - - - - - - - - - -
        System propertyMeaning
        spark.deploy.recoveryModeSet to ZOOKEEPER to enable standby Master recovery mode (default: NONE).
        spark.deploy.zookeeper.urlThe ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).
        spark.deploy.zookeeper.dirThe directory in ZooKeeper to store recovery state (default: /spark).
        +In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). From a41b68b954ba47284a1df312f0aaea29b0721b0a Mon Sep 17 00:00:00 2001 From: Nilanjan Raychaudhuri Date: Mon, 1 Feb 2016 13:33:24 -0800 Subject: [PATCH 039/107] [SPARK-12265][MESOS] Spark calls System.exit inside driver instead of throwing exception This takes over #10729 and makes sure that `spark-shell` fails with a proper error message. There is a slight behavioral change: before this change `spark-shell` would exit, while now the REPL is still there, but `sc` and `sqlContext` are not defined and the error is visible to the user. Author: Nilanjan Raychaudhuri Author: Iulian Dragos Closes #10921 from dragos/pr/10729. --- .../cluster/mesos/MesosClusterScheduler.scala | 1 + .../cluster/mesos/MesosSchedulerBackend.scala | 1 + .../cluster/mesos/MesosSchedulerUtils.scala | 21 +++++++++++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 05fda0fded7f8..e77d77208ccba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -573,6 +573,7 @@ private[spark] class MesosClusterScheduler( override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} override def error(driver: SchedulerDriver, error: String): Unit = { logError("Error received: " + error) + markErr() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index eaf0cb06d6c73..a8bf79a78cf5d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -375,6 +375,7 @@ private[spark] class MesosSchedulerBackend( override def error(d: SchedulerDriver, message: String) { inClassLoader() { logError("Mesos error: " + message) + markErr() scheduler.error(message) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 010caff3e39b2..f9f5da9bc8df6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -106,28 +106,37 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.await() return } + @volatile + var error: Option[Exception] = None + // We create a new thread that will block inside `mesosDriver.run` + // until the scheduler exists new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { setDaemon(true) - 
override def run() { - mesosDriver = newDriver try { + mesosDriver = newDriver val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { - System.exit(1) + error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) + markErr() } } catch { case e: Exception => { logError("driver.run() failed", e) - System.exit(1) + error = Some(e) + markErr() } } } }.start() registerLatch.await() + + // propagate any error to the calling thread. This ensures that SparkContext creation fails + // without leaving a broken context that won't be able to schedule any tasks + error.foreach(throw _) } } @@ -144,6 +153,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.countDown() } + protected def markErr(): Unit = { + registerLatch.countDown() + } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { val builder = Resource.newBuilder() .setName(name) From c9b89a0a0921ce3d52864afd4feb7f37b90f7b46 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Mon, 1 Feb 2016 13:38:38 -0800 Subject: [PATCH 040/107] =?UTF-8?q?[SPARK-12979][MESOS]=20Don=E2=80=99t=20?= =?UTF-8?q?resolve=20paths=20on=20the=20local=20file=20system=20in=20Mesos?= =?UTF-8?q?=20scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The driver filesystem is likely different from where the executors will run, so resolving paths (and symlinks, etc.) will lead to invalid paths on executors. Author: Iulian Dragos Closes #10923 from dragos/issue/canonical-paths. --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 58c30e7d97886..2f095b86c69ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -179,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend( .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { - val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index e77d77208ccba..8cda4ff0eb3b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -394,7 +394,7 @@ private[spark] class MesosClusterScheduler( .getOrElse { throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } - val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getPath // Sandbox points to the current directory 
by default with Mesos. (cmdExecutable, ".") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a8bf79a78cf5d..340f29bac9218 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -125,7 +125,7 @@ private[spark] class MesosSchedulerBackend( val executorBackendName = classOf[MesosExecutorBackend].getName if (uri.isEmpty) { - val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath + val executorPath = new File(executorSparkHome, "/bin/spark-class").getPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to From 064b029c6a15481fc4dfb147100c19a68cd1cc95 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 1 Feb 2016 13:56:14 -0800 Subject: [PATCH 041/107] [SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch. This includes: float, boolean, short, decimal and calendar interval. Decimal is mapped to long or byte array depending on the size and calendar interval is mapped to a struct of int and long. The only remaining type is map. The schema mapping is straightforward but we might want to revisit how we deal with this in the rest of the execution engine. Author: Nong Li Closes #10961 from nongli/spark-13043. --- .../apache/spark/sql/types/DecimalType.scala | 22 +++ .../execution/vectorized/ColumnVector.java | 180 +++++++++++++++++- .../vectorized/ColumnVectorUtils.java | 34 +++- .../execution/vectorized/ColumnarBatch.java | 46 +++-- .../vectorized/OffHeapColumnVector.java | 98 +++++++++- .../vectorized/OnHeapColumnVector.java | 94 ++++++++- .../vectorized/ColumnarBatchSuite.scala | 44 ++++- .../org/apache/spark/unsafe/Platform.java | 8 + 8 files changed, 484 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index cf5322125bd72..5dd661ee6b339 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a long + */ + def is64BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_LONG_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that doesn't fit inside a long + */ + def isByteArrayDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision > Decimal.MAX_LONG_DIGITS + case _ => false + } + } + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a0bf8734b6545..a5bc506a65ac2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; +import 
java.math.BigDecimal; +import java.math.BigInteger; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -102,18 +105,36 @@ public Object[] array() { DataType dt = data.dataType(); Object[] list = new Object[length]; - if (dt instanceof ByteType) { + if (dt instanceof BooleanType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getBoolean(offset + i); + } + } + } else if (dt instanceof ByteType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getByte(offset + i); } } + } else if (dt instanceof ShortType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getShort(offset + i); + } + } } else if (dt instanceof IntegerType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getInt(offset + i); } } + } else if (dt instanceof FloatType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getFloat(offset + i); + } + } } else if (dt instanceof DoubleType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { @@ -126,12 +147,25 @@ public Object[] array() { list[i] = data.getLong(offset + i); } } + } else if (dt instanceof DecimalType) { + DecimalType decType = (DecimalType)dt; + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getDecimal(i, decType.precision(), decType.scale()); + } + } } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); } } + } else if (dt instanceof CalendarIntervalType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getInterval(i); + } + } } else { throw new NotImplementedException("Type " + dt); } @@ -170,7 +204,14 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -181,17 +222,22 @@ public UTF8String getUTF8String(int ordinal) { @Override public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = data.getByteArray(offset + ordinal); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); } @Override public InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return data.getStruct(offset + ordinal); } @Override @@ -279,6 +325,21 @@ public void reset() { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. 
+ */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Returns the value for rowId. + */ + public abstract boolean getBoolean(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -299,6 +360,26 @@ public void reset() { */ public abstract byte getByte(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract short getShort(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -351,6 +432,33 @@ public void reset() { */ public abstract long getLong(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract float getFloat(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -369,7 +477,7 @@ public void reset() { /** * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formated doubles. + * The data in src must be ieee formatted doubles. 
*/ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -469,6 +577,20 @@ public final int appendNotNulls(int count) { return result; } + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); @@ -491,6 +613,28 @@ public final int appendBytes(int length, byte[] src, int offset) { return result; } + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendInt(int v) { reserve(elementsAppended + 1); putInt(elementsAppended, v); @@ -535,6 +679,20 @@ public final int appendLongs(int length, long[] src, int offset) { return result; } + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); @@ -661,7 +819,8 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.capacity = capacity; this.type = type; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -682,6 +841,13 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. 
+ this.childColumns = new ColumnVector[2]; + this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 6c651a759d250..453bc15e13503 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Iterator; import java.util.List; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.commons.lang.NotImplementedException; @@ -59,19 +62,44 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) { private static void appendValue(ColumnVector dst, DataType t, Object o) { if (o == null) { - dst.appendNull(); + if (t instanceof CalendarIntervalType) { + dst.appendStruct(true); + } else { + dst.appendNull(); + } } else { - if (t == DataTypes.ByteType) { - dst.appendByte(((Byte)o).byteValue()); + if (t == DataTypes.BooleanType) { + dst.appendBoolean(((Boolean)o).booleanValue()); + } else if (t == DataTypes.ByteType) { + dst.appendByte(((Byte) o).byteValue()); + } else if (t == DataTypes.ShortType) { + dst.appendShort(((Short)o).shortValue()); } else if (t == DataTypes.IntegerType) { dst.appendInt(((Integer)o).intValue()); } else if (t == DataTypes.LongType) { dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.FloatType) { + dst.appendFloat(((Float)o).floatValue()); } else if (t == DataTypes.DoubleType) { dst.appendDouble(((Double)o).doubleValue()); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(); dst.appendByteArray(b, 0, b.length); + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + dst.appendLong(d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + dst.appendByteArray(bytes, 0, bytes.length); + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)o; + dst.appendStruct(false); + dst.getChildColumn(0).appendInt(c.months); + dst.getChildColumn(1).appendLong(c.microseconds); } else { throw new NotImplementedException("Type " + t); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 5a575811fa896..dbad5e070f1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; 
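For orientation on the getDecimal/getBinary hunks around here: this patch splits decimals by precision, storing a decimal that fits in a long as its unscaled long and a wider decimal as the unscaled BigInteger's byte array. A minimal Scala sketch of that split, not part of this patch and assuming Decimal.MAX_LONG_DIGITS is 18 as in Spark's Decimal object:

  import org.apache.spark.sql.types.{Decimal, DecimalType}

  // Precision <= Decimal.MAX_LONG_DIGITS: the column stores d.toUnscaledLong directly.
  DecimalType.is64BitDecimalType(DecimalType.IntDecimal)       // true
  // Wider precision: the column stores the unscaled BigInteger's bytes.
  DecimalType.isByteArrayDecimalType(new DecimalType(30, 10))  // true

  // Round-trip of the wide case, mirroring the getDecimal/getBinary pair added in this file.
  val d = Decimal(BigDecimal("12345678901234567890.0123456789"), 30, 10)
  val bytes = d.toJavaBigDecimal.unscaledValue.toByteArray
  val restored = Decimal(new java.math.BigDecimal(new java.math.BigInteger(bytes), 10), 30, 10)
  assert(restored.compare(d) == 0)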
@@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -150,44 +153,40 @@ public final boolean anyNull() { } @Override - public final boolean isNullAt(int ordinal) { - return columns[ordinal].getIsNull(rowId); - } + public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); } @Override - public final boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } + public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } @Override public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override - public final short getShort(int ordinal) { - throw new NotImplementedException(); - } + public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } @Override - public final int getInt(int ordinal) { - return columns[ordinal].getInt(rowId); - } + public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } @Override public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override - public final float getFloat(int ordinal) { - throw new NotImplementedException(); - } + public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } @Override - public final double getDouble(int ordinal) { - return columns[ordinal].getDouble(rowId); - } + public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + // TODO: best perf? 
+ byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -198,12 +197,17 @@ public final UTF8String getUTF8String(int ordinal) { @Override public final byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = columns[ordinal].getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public final CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 335124fd5a603..22c5e5fc81a4a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,11 +19,15 @@ import java.nio.ByteOrder; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; import org.apache.spark.sql.types.IntegerType; import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -121,6 +125,26 @@ public final boolean getIsNull(int rowId) { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 
1 : 0); + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, v); + } + } + + @Override + public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + // // APIs dealing with Bytes // @@ -148,6 +172,34 @@ public final byte getByte(int rowId) { return Platform.getByte(null, data + rowId); } + // + // APIs dealing with shorts + // + + @Override + public final void putShort(int rowId, short value) { + Platform.putShort(null, data + 2 * rowId, value); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + long offset = data + 2 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putShort(null, offset, value); + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, + null, data + 2 * rowId, count * 2); + } + + @Override + public final short getShort(int rowId) { + return Platform.getShort(null, data + 2 * rowId); + } + // // APIs dealing with ints // @@ -216,6 +268,41 @@ public final long getLong(int rowId) { return Platform.getLong(null, data + 8 * rowId); } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + Platform.putFloat(null, data + rowId * 4, value); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, value); + } + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { + return Platform.getFloat(null, data + rowId * 4); + } + + // // APIs dealing with doubles // @@ -241,7 +328,7 @@ public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { @Override public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex, + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId * 8, count * 8); } @@ -300,11 +387,14 @@ private final void reserveInternal(int newCapacity) { Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); this.offsetData = Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof ByteType) { + } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); - } else if (type instanceof IntegerType) { + } else if (type instanceof ShortType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + } else if (type instanceof IntegerType || type instanceof FloatType) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof LongType || type instanceof DoubleType) { + } else if (type instanceof LongType || type instanceof DoubleType || + DecimalType.is64BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); } else if 
(resultStruct != null) { // Nothing to store. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 8197fa11cd4c8..32356334c031f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector { // Array for each type. Only 1 is populated for any type. private byte[] byteData; + private short[] shortData; private int[] intData; private long[] longData; + private float[] floatData; private double[] doubleData; // Only set if type is Array. @@ -104,6 +106,30 @@ public final boolean getIsNull(int rowId) { return nulls[rowId] == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + byteData[rowId] = (byte)((value) ? 1 : 0); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = v; + } + } + + @Override + public final boolean getBoolean(int rowId) { + return byteData[rowId] == 1; + } + + // + // // APIs dealing with Bytes // @@ -130,6 +156,33 @@ public final byte getByte(int rowId) { return byteData[rowId]; } + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + shortData[rowId] = value; + } + + @Override + public final void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortData[i + rowId] = value; + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + System.arraycopy(src, srcIndex, shortData, rowId, count); + } + + @Override + public final short getShort(int rowId) { + return shortData[rowId]; + } + + // // APIs dealing with Ints // @@ -202,6 +255,31 @@ public final long getLong(int rowId) { return longData[rowId]; } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { floatData[rowId] = value; } + + @Override + public final void putFloats(int rowId, int count, float value) { + Arrays.fill(floatData, rowId, rowId + count, value); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + System.arraycopy(src, srcIndex, floatData, rowId, count); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { return floatData[rowId]; } // // APIs dealing with doubles @@ -277,7 +355,7 @@ public final void reserve(int requiredCapacity) { // Spilt this function out since it is the slow path. 
private final void reserveInternal(int newCapacity) { - if (this.resultArray != null) { + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { @@ -286,18 +364,30 @@ private final void reserveInternal(int newCapacity) { } arrayLengths = newLengths; arrayOffsets = newOffsets; + } else if (type instanceof BooleanType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; } else if (type instanceof ByteType) { byte[] newData = new byte[newCapacity]; if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); byteData = newData; + } else if (type instanceof ShortType) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; } else if (type instanceof IntegerType) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; - } else if (type instanceof LongType) { + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); longData = newData; + } else if (type instanceof FloatType) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; } else if (type instanceof DoubleType) { double[] newData = new double[newCapacity]; if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 67cc08b6fc8ba..445f311107e33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { test("Null Apis") { @@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } - private def doubleEquals(d1: Double, d2: Double): Boolean = { if (d1.isNaN && d2.isNaN) { true @@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) if (!r1.isNullAt(v._2)) { v._1.dataType match { + case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed) case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed) case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)), + "Seed = " + seed) case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), "Seed = " + seed) + case t: DecimalType => + val d1 
= r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal + val d2 = r2.getDecimal(v._2) + assert(d1.compare(d2) == 0, "Seed = " + seed) case StringType => assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case CalendarIntervalType => + assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval]) case ArrayType(childType, n) => val a1 = r1.getArray(v._2).array val a2 = r2.getList(v._2).toArray @@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } } + case FloatType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), + "Seed = " + seed) + i += 1 + } + } + + case t: DecimalType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal + val d2 = a2(i).asInstanceOf[java.math.BigDecimal] + assert(d1.compare(d2) == 0, "Seed = " + seed) + } + i += 1 + } + case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => @@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite { * results. */ def testRandomRows(flatSchema: Boolean, numFields: Int) { - // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. - val types = Array(ByteType, IntegerType, LongType, DoubleType) + // TODO: Figure out why StringType doesn't work on jenkins. + val types = Array( + BooleanType, ByteType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10), + CalendarIntervalType) val seed = System.nanoTime() - val NUM_ROWS = 500 + val NUM_ROWS = 200 val NUM_ITERS = 1000 val random = new Random(seed) var i = 0 @@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Random flat schema") { - testRandomRows(true, 10) + testRandomRows(true, 15) } test("Random nested schema") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b29bf6a464b30..18761bfd222a2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -27,10 +27,14 @@ public final class Platform { public static final int BYTE_ARRAY_OFFSET; + public static final int SHORT_ARRAY_OFFSET; + public static final int INT_ARRAY_OFFSET; public static final int LONG_ARRAY_OFFSET; + public static final int FLOAT_ARRAY_OFFSET; + public static final int DOUBLE_ARRAY_OFFSET; public static int getInt(Object object, long offset) { @@ -168,13 +172,17 @@ public static void throwException(Throwable t) { if (_UNSAFE != null) { BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; } } From a2973fed30fbe9a0b12e1c1225359fdf55d322b4 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 1 Feb 2016 13:57:48 -0800 Subject: [PATCH 042/107] =?UTF-8?q?Fix=20for=20[SPARK-12854][SQL]=20Implem?= =?UTF-8?q?ent=20complex=20types=20support=20in=20Columna=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit …rBatch Fixes build for Scala 2.11. Author: Jacek Laskowski Closes #10946 from jaceklaskowski/SPARK-12854-fix. --- .../spark/sql/execution/vectorized/OffHeapColumnVector.java | 2 +- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 22c5e5fc81a4a..7a224d19d15b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -367,7 +367,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 32356334c031f..c42bbd642ecae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -331,7 +331,7 @@ public final void putArray(int rowId, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } From be7a2fc0716b7d25327b6f8f683390fc62532e3b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 14:11:52 -0800 Subject: [PATCH 043/107] [SPARK-13078][SQL] API and test cases for internal catalog This pull request creates an internal catalog API. The creation of this API is the first step towards consolidating SQLContext and HiveContext. I envision we will have two different implementations in Spark 2.0: (1) a simple in-memory implementation, and (2) an implementation based on the current HiveClient (ClientWrapper). I took a look at what Hive's internal metastore interface/implementation, and then created this API based on it. I believe this is the minimal set needed in order to achieve all the needed functionality. Author: Reynold Xin Closes #10982 from rxin/SPARK-13078. 
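For orientation, a minimal usage sketch of the Catalog API added below; this is illustrative only and not part of the patch, with the Database and Table arguments mirroring the helpers in CatalogTestCases:

  import org.apache.spark.sql.catalyst.catalog._

  // Catalog is the contract; InMemoryCatalog is the first (ephemeral) implementation.
  val catalog: Catalog = new InMemoryCatalog
  catalog.createDatabase(Database("db1", "db1 description", "uri", Map.empty), ifNotExists = false)
  catalog.createTable("db1",
    Table("tbl1", "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty,
      "MANAGED_TABLE", 0, 0, None, None),
    ignoreIfExists = false)
  assert(catalog.listDatabases("db*").toSet == Set("db1"))
  assert(catalog.listTables("db1").toSet == Set("tbl1"))
  assert(catalog.getTable("db1", "tbl1").tableType == "MANAGED_TABLE")

The same sequence should behave identically against a future Hive-backed implementation, which is exactly what CatalogTestCases is structured to verify for any Catalog subclass.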
--- .../catalyst/catalog/InMemoryCatalog.scala | 246 ++++++++++++++++ .../sql/catalyst/catalog/interface.scala | 178 ++++++++++++ .../catalyst/catalog/CatalogTestCases.scala | 263 ++++++++++++++++++ .../catalog/InMemoryCatalogSuite.scala | 23 ++ 4 files changed, 710 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala new file mode 100644 index 0000000000000..9e6dfb7e9506f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException + + +/** + * An in-memory (ephemeral) implementation of the system catalog. + * + * All public methods should be synchronized for thread-safety. 
+ */ +class InMemoryCatalog extends Catalog { + + private class TableDesc(var table: Table) { + val partitions = new mutable.HashMap[String, TablePartition] + } + + private class DatabaseDesc(var db: Database) { + val tables = new mutable.HashMap[String, TableDesc] + val functions = new mutable.HashMap[String, Function] + } + + private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] + + private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val regex = pattern.replaceAll("\\*", ".*").r + names.filter { funcName => regex.pattern.matcher(funcName).matches() } + } + + private def existsFunction(db: String, funcName: String): Boolean = { + catalog(db).functions.contains(funcName) + } + + private def existsTable(db: String, table: String): Boolean = { + catalog(db).tables.contains(table) + } + + private def assertDbExists(db: String): Unit = { + if (!catalog.contains(db)) { + throw new AnalysisException(s"Database $db does not exist") + } + } + + private def assertFunctionExists(db: String, funcName: String): Unit = { + assertDbExists(db) + if (!existsFunction(db, funcName)) { + throw new AnalysisException(s"Function $funcName does not exists in $db database") + } + } + + private def assertTableExists(db: String, table: String): Unit = { + assertDbExists(db) + if (!existsTable(db, table)) { + throw new AnalysisException(s"Table $table does not exists in $db database") + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + if (catalog.contains(dbDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") + } + } else { + catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition)) + } + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { + if (catalog.contains(db)) { + if (!cascade) { + // If cascade is false, make sure the database is empty. + if (catalog(db).tables.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") + } + if (catalog(db).functions.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more functions exist.") + } + } + // Remove the database. 
+ catalog.remove(db) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Database $db does not exist") + } + } + } + + override def alterDatabase(db: String, dbDefinition: Database): Unit = synchronized { + assertDbExists(db) + assert(db == dbDefinition.name) + catalog(db).db = dbDefinition + } + + override def getDatabase(db: String): Database = synchronized { + assertDbExists(db) + catalog(db).db + } + + override def listDatabases(): Seq[String] = synchronized { + catalog.keySet.toSeq + } + + override def listDatabases(pattern: String): Seq[String] = synchronized { + filterPattern(listDatabases(), pattern) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, tableDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") + } + } else { + catalog(db).tables.put(tableDefinition.name, new TableDesc(tableDefinition)) + } + } + + override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, table)) { + catalog(db).tables.remove(table) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + assertTableExists(db, oldName) + val oldDesc = catalog(db).tables(oldName) + oldDesc.table = oldDesc.table.copy(name = newName) + catalog(db).tables.put(newName, oldDesc) + catalog(db).tables.remove(oldName) + } + + override def alterTable(db: String, table: String, tableDefinition: Table): Unit = synchronized { + assertTableExists(db, table) + assert(table == tableDefinition.name) + catalog(db).tables(table).table = tableDefinition + } + + override def getTable(db: String, table: String): Table = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).table + } + + override def listTables(db: String): Seq[String] = synchronized { + assertDbExists(db) + catalog(db).tables.keySet.toSeq + } + + override def listTables(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + filterPattern(listTables(db), pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def alterPartition(db: String, table: String, part: TablePartition) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + assertDbExists(db) + + if (existsFunction(db, func.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Function $func already exists in $db database") + } + } else { + catalog(db).functions.put(func.name, func) + } + } + + override def dropFunction(db: 
String, funcName: String): Unit = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions.remove(funcName) + } + + override def alterFunction(db: String, funcName: String, funcDefinition: Function) + : Unit = synchronized { + assertFunctionExists(db, funcName) + if (funcName != funcDefinition.name) { + // Also a rename; remove the old one and add the new one back + catalog(db).functions.remove(funcName) + } + catalog(db).functions.put(funcName, funcDefinition) + } + + override def getFunction(db: String, funcName: String): Function = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions(funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + val regex = pattern.replaceAll("\\*", ".*").r + filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala new file mode 100644 index 0000000000000..a6caf91f3304b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + + +/** + * Interface for the system catalog (of columns, partitions, tables, and databases). + * + * This is only used for non-temporary items, and implementations must be thread-safe as they + * can be accessed in multiple threads. + * + * Implementations should throw [[AnalysisException]] when table or database don't exist. 
+ */ +abstract class Catalog { + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + + def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit + + def alterDatabase(db: String, dbDefinition: Database): Unit + + def getDatabase(db: String): Database + + def listDatabases(): Seq[String] + + def listDatabases(pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + def createTable(db: String, tableDefinition: Table, ignoreIfExists: Boolean): Unit + + def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit + + def renameTable(db: String, oldName: String, newName: String): Unit + + def alterTable(db: String, table: String, tableDefinition: Table): Unit + + def getTable(db: String, table: String): Table + + def listTables(db: String): Seq[String] + + def listTables(db: String, pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: need more functions for partitioning. + + def alterPartition(db: String, table: String, part: TablePartition): Unit + + def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + def createFunction(db: String, funcDefinition: Function, ignoreIfExists: Boolean): Unit + + def dropFunction(db: String, funcName: String): Unit + + def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit + + def getFunction(db: String, funcName: String): Function + + def listFunctions(db: String, pattern: String): Seq[String] + +} + + +/** + * A function defined in the catalog. + * + * @param name name of the function + * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + */ +case class Function( + name: String, + className: String +) + + +/** + * Storage format, used to describe how a partition or a table is stored. + */ +case class StorageFormat( + locationUri: String, + inputFormat: String, + outputFormat: String, + serde: String, + serdeProperties: Map[String, String] +) + + +/** + * A column in a table. + */ +case class Column( + name: String, + dataType: String, + nullable: Boolean, + comment: String +) + + +/** + * A partition (Hive style) defined in the catalog. + * + * @param values values for the partition columns + * @param storage storage format of the partition + */ +case class TablePartition( + values: Seq[String], + storage: StorageFormat +) + + +/** + * A table defined in the catalog. + * + * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the + * future once we have a better understanding of how we want to handle skewed columns. 
+ */ +case class Table( + name: String, + description: String, + schema: Seq[Column], + partitionColumns: Seq[Column], + sortColumns: Seq[Column], + storage: StorageFormat, + numBuckets: Int, + properties: Map[String, String], + tableType: String, + createTime: Long, + lastAccessTime: Long, + viewOriginalText: Option[String], + viewText: Option[String]) { + + require(tableType == "EXTERNAL_TABLE" || tableType == "INDEX_TABLE" || + tableType == "MANAGED_TABLE" || tableType == "VIRTUAL_VIEW") +} + + +/** + * A database defined in the catalog. + */ +case class Database( + name: String, + description: String, + locationUri: String, + properties: Map[String, String] +) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala new file mode 100644 index 0000000000000..ab9d5ac8a20eb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + + +/** + * A reasonable complete test suite (i.e. behaviors) for a [[Catalog]]. + * + * Implementations of the [[Catalog]] interface can create test suites by extending this. 
+ */ +abstract class CatalogTestCases extends SparkFunSuite { + + protected def newEmptyCatalog(): Catalog + + /** + * Creates a basic catalog, with the following structure: + * + * db1 + * db2 + * - tbl1 + * - tbl2 + * - func1 + */ + private def newBasicCatalog(): Catalog = { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("db1"), ifNotExists = false) + catalog.createDatabase(newDb("db2"), ifNotExists = false) + + catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog + } + + private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + + private def newDb(name: String = "default"): Database = + Database(name, name + " description", "uri", Map.empty) + + private def newTable(name: String): Table = + Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, + None, None) + + private def newFunc(name: String): Function = Function(name, "class.name") + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + test("basic create, drop and list databases") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb(), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default")) + + catalog.createDatabase(newDb("default2"), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default", "default2")) + } + + test("get database when a database exists") { + val db1 = newBasicCatalog().getDatabase("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } + + test("get database should throw exception when the database does not exist") { + intercept[AnalysisException] { newBasicCatalog().getDatabase("db_that_does_not_exist") } + } + + test("list databases without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases().toSet == Set("db1", "db2")) + } + + test("list databases with pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } + + test("drop database") { + val catalog = newBasicCatalog() + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("db2")) + } + + test("drop database when the database is not empty") { + // Throw exception if there are functions left + val catalog1 = newBasicCatalog() + catalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false) + catalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // Throw exception if there are tables left + val catalog2 = newBasicCatalog() + catalog2.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // When cascade is true, it should drop them + val catalog3 = newBasicCatalog() + catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog3.listDatabases().toSet == Set("db1")) + } + + test("drop database when the database does not exist") { + val catalog = 
newBasicCatalog() + + intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) + } + + test("alter database") { + val catalog = newBasicCatalog() + catalog.alterDatabase("db1", Database("db1", "new description", "lll", Map.empty)) + assert(catalog.getDatabase("db1").description == "new description") + } + + test("alter database should throw exception when the database does not exist") { + intercept[AnalysisException] { + newBasicCatalog().alterDatabase("no_db", Database("no_db", "ddd", "lll", Map.empty)) + } + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + test("drop table") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable("db2", "tbl1", ignoreIfNotExists = false) + assert(catalog.listTables("db2").toSet == Set("tbl2")) + } + + test("drop table when database / table does not exist") { + val catalog = newBasicCatalog() + + // Should always throw exception when the database does not exist + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false) + } + + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true) + } + + // Should throw exception when the table does not exist, if ignoreIfNotExists is false + intercept[AnalysisException] { + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false) + } + + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true) + } + + test("rename table") { + val catalog = newBasicCatalog() + + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable("db2", "tbl1", "tblone") + assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2")) + } + + test("rename table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.renameTable("unknown_db", "unknown_table", "unknown_table") + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.renameTable("db2", "unknown_table", "unknown_table") + } + } + + test("alter table") { + val catalog = newBasicCatalog() + catalog.alterTable("db2", "tbl1", newTable("tbl1").copy(createTime = 10)) + assert(catalog.getTable("db2", "tbl1").createTime == 10) + } + + test("alter table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.alterTable("unknown_db", "unknown_table", newTable("unknown_table")) + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.alterTable("db2", "unknown_table", newTable("unknown_table")) + } + } + + test("get table") { + assert(newBasicCatalog().getTable("db2", "tbl1").name == "tbl1") + } + + test("get table when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getTable("unknown_db", "unknown_table") + } + + intercept[AnalysisException] { + catalog.getTable("db2", "unknown_table") + } + } + + test("list tables without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db1").toSet == 
Set.empty) + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + } + + test("list tables with pattern") { + val catalog = newBasicCatalog() + + // Test when database does not exist + intercept[AnalysisException] { catalog.listTables("unknown_db") } + + assert(catalog.listTables("db1", "*").toSet == Set.empty) + assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "*1").toSet == Set("tbl1")) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for partitions + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for functions +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala new file mode 100644 index 0000000000000..871f0a0f46a22 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +/** Test suite for the [[InMemoryCatalog]]. */ +class InMemoryCatalogSuite extends CatalogTestCases { + override protected def newEmptyCatalog(): Catalog = new InMemoryCatalog +} From 715a19d56fc934d4aec5025739ff650daf4580b7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Feb 2016 16:23:17 -0800 Subject: [PATCH 044/107] [SPARK-12637][CORE] Print stage info of finished stages properly Improve printing of StageInfo in onStageCompleted See also https://github.com/apache/spark/pull/10585 Author: Sean Owen Closes #10922 from srowen/SPARK-12637. 
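For illustration (the stage name and timings below are made up, and assuming StageInfo.getStatusString reports "succeeded" for a completed stage), the log line built by the getStatusDetail helper in the hunk below looks roughly like:

  Finished stage: Stage(1, 0); Name: 'count at MyApp.scala:42'; Status: succeeded; numTasks: 8; Took: 1234 msec

For a failed stage, the failure reason is appended in parentheses immediately after the status, and a missing submission time renders the duration as "-".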
--- .../org/apache/spark/scheduler/SparkListener.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index ed3adbd81c28e..7b09c2eded0be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -270,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) + this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) // Shuffle write @@ -297,6 +297,17 @@ class StatsReportListener extends SparkListener with Logging { taskInfoMetrics.clear() } + private def getStatusDetail(info: StageInfo): String = { + val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") + val timeTaken = info.submissionTime.map( + x => info.completionTime.getOrElse(System.currentTimeMillis()) - x + ).getOrElse("-") + + s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + + s"Took: $timeTaken msec" + } + } private[spark] object StatsReportListener extends Logging { From 0df3cfb8ab4d584c95db6c340694e199d7b59e9e Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 1 Feb 2016 16:55:21 -0800 Subject: [PATCH 045/107] [SPARK-12790][CORE] Remove HistoryServer old multiple files format Removed isLegacyLogDirectory code path and updated tests andrewor14 Author: felixcheung Closes #10860 from felixcheung/historyserverformat. 
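For readers who have not seen the old format, the two on-disk layouts differ roughly as follows; the listing is illustrative only, with application IDs taken from the test resources renamed below. The legacy pre-1.3 layout wrote one directory per application containing an EVENT_LOG_1 file plus SPARK_VERSION_* and APPLICATION_COMPLETE marker files, while the current layout writes a single consolidated event log file per attempt, carrying the ".inprogress" suffix (EventLoggingListener.IN_PROGRESS) until the application finishes:

    spark-events/local-1422981759269/EVENT_LOG_1            (legacy: event data)
    spark-events/local-1422981759269/SPARK_VERSION_1.2.0    (legacy: Spark version marker)
    spark-events/local-1422981759269/APPLICATION_COMPLETE   (legacy: completion marker)
    spark-events/local-1422981759269                        (current: one consolidated log file)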
--- .rat-excludes | 12 +- .../deploy/history/FsHistoryProvider.scala | 124 ++---------------- .../scheduler/EventLoggingListener.scala | 2 - .../EVENT_LOG_1 => local-1422981759269} | 0 .../local-1422981759269/APPLICATION_COMPLETE | 0 .../local-1422981759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1422981780767} | 0 .../local-1422981780767/APPLICATION_COMPLETE | 0 .../local-1422981780767/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1425081759269} | 0 .../local-1425081759269/APPLICATION_COMPLETE | 0 .../local-1425081759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426533911241} | 0 .../local-1426533911241/APPLICATION_COMPLETE | 0 .../local-1426533911241/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426633911242} | 0 .../local-1426633911242/APPLICATION_COMPLETE | 0 .../local-1426633911242/SPARK_VERSION_1.2.0 | 0 .../history/FsHistoryProviderSuite.scala | 95 +------------- .../deploy/history/HistoryServerSuite.scala | 25 +--- 20 files changed, 23 insertions(+), 235 deletions(-) rename core/src/test/resources/spark-events/{local-1422981759269/EVENT_LOG_1 => local-1422981759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1422981780767/EVENT_LOG_1 => local-1422981780767} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1425081759269/EVENT_LOG_1 => local-1425081759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426533911241/EVENT_LOG_1 => local-1426533911241} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426633911242/EVENT_LOG_1 => local-1426633911242} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 diff --git a/.rat-excludes b/.rat-excludes index 874a6ee9f4043..8b5061415ff4c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -73,12 +73,12 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation -local-1422981759269/* -local-1422981780767/* -local-1425081759269/* -local-1426533911241/* -local-1426633911242/* -local-1430917381534/* +local-1422981759269 +local-1422981780767 +local-1425081759269 +local-1426533911241 +local-1426633911242 +local-1430917381534 local-1430917381535_1 local-1430917381535_2 DESCRIPTION diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 22e4155cc5452..9648959dbacb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,9 +248,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logInfos: Seq[FileStatus] = 
statusList .filter { entry => try { - getModificationTime(entry).map { time => - time >= lastScanTime - }.getOrElse(false) + !entry.isDirectory() && (entry.getModificationTime() >= lastScanTime) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -261,9 +259,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => - val mod1 = getModificationTime(entry1).getOrElse(-1L) - val mod2 = getModificationTime(entry2).getOrElse(-1L) - mod1 >= mod2 + entry1.getModificationTime() >= entry2.getModificationTime() } logInfos.grouped(20) @@ -341,19 +337,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get }.foreach { attempt => val logPath = new Path(logDir, attempt.logPath) - // If this is a legacy directory, then add the directory to the zipStream and add - // each file to that directory. - if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { - val files = fs.listStatus(logPath) - zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) - zipStream.closeEntry() - files.foreach { file => - val path = file.getPath - zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) - } - } else { - zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) - } + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) } } finally { zipStream.close() @@ -527,12 +511,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") - val logInput = - if (isLegacyLogDirectory(eventLog)) { - openLegacyEventLog(logPath) - } else { - EventLoggingListener.openEventLog(logPath, fs) - } + val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener val appCompleted = isApplicationCompleted(eventLog) @@ -540,9 +519,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.replay(logInput, logPath.toString, !appCompleted) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. Some old versions of Spark generate logs without an app ID, so let - // logs generated by those versions go through. - if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + // try to show their UI. + if (appListener.appId.isDefined) { Some(new FsApplicationAttemptInfo( logPath.getName(), appListener.appName.getOrElse(NOT_STARTED), @@ -550,7 +528,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, + eventLog.getModificationTime(), appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted)) } else { @@ -561,91 +539,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Loads a legacy log directory. This assumes that the log directory contains a single event - * log file (along with other metadata files), which is the case for directories generated by - * the code in previous releases. - * - * @return input stream that holds one JSON record per line. 
- */ - private[history] def openLegacyEventLog(dir: Path): InputStream = { - val children = fs.listStatus(dir) - var eventLogPath: Path = null - var codecName: Option[String] = None - - children.foreach { child => - child.getPath().getName() match { - case name if name.startsWith(LOG_PREFIX) => - eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => - codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - case _ => - } - } - - if (eventLogPath == null) { - throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") - } - - val codec = try { - codecName.map { c => CompressionCodec.createCodec(conf, c) } - } catch { - case e: Exception => - throw new IllegalArgumentException(s"Unknown compression codec $codecName.") - } - - val in = new BufferedInputStream(fs.open(eventLogPath)) - codec.map(_.compressedInputStream(in)).getOrElse(in) - } - - /** - * Return whether the specified event log path contains a old directory-based event log. - * Previously, the event log of an application comprises of multiple files in a directory. - * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. - * See SPARK-2261 for more detail. - */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory - - /** - * Returns the modification time of the given event log. If the status points at an empty - * directory, `None` is returned, indicating that there isn't an event log at that location. - */ - private def getModificationTime(fsEntry: FileStatus): Option[Long] = { - if (isLegacyLogDirectory(fsEntry)) { - val statusList = fs.listStatus(fsEntry.getPath) - if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None - } else { - Some(fsEntry.getModificationTime()) - } - } - /** * Return true when the application has completed. */ private def isApplicationCompleted(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.exists(new Path(entry.getPath(), APPLICATION_COMPLETE)) - } else { - !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) - } - } - - /** - * Returns whether the version of Spark that generated logs records app IDs. App IDs were added - * in Spark 1.1. - */ - private def sparkVersionHasAppId(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.listStatus(entry.getPath()) - .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } - .map { status => - val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) - version != "1.0" && version != "1.1" - } - .getOrElse(true) - } else { - true - } + !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) } /** @@ -670,12 +568,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - - // Constants used to parse Spark 1.0.0 log directories. 
- val LOG_PREFIX = "EVENT_LOG_" - val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 36f2b74f948f1..01fee46e73a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -232,8 +232,6 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" - val SPARK_VERSION_KEY = "SPARK_VERSION" - val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) diff --git a/core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981759269 diff --git a/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981780767 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981780767 diff --git a/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1425081759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1425081759269 diff --git a/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426533911241 similarity index 100% rename from 
core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426533911241 diff --git a/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426633911242 similarity index 100% rename from core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426633911242 diff --git a/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 6cbf911395a84..3baa2e2ddad31 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -69,7 +69,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new File(logPath) } - test("Parse new and old application logs") { + test("Parse application logs") { val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. @@ -95,26 +95,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc None) ) - // Write an old-style application log. - val oldAppComplete = writeOldLog("old1", "1.0", None, true, - SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - - // Check for logs so that we force the older unfinished app to be loaded, to make - // sure unfinished apps are also sorted correctly. - provider.checkForLogs() - - // Write an unfinished app, old-style. - val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, - SparkListenerApplicationStart("old2", None, 2L, "test", None) - ) - - // Force a reload of data from the log directory, and check that both logs are loaded. + // Force a reload of data from the log directory, and check that logs are loaded. // Take the opportunity to check that the offset checks work as expected. 
updateAndCheck(provider) { list => - list.size should be (5) - list.count(_.attempts.head.completed) should be (3) + list.size should be (3) + list.count(_.attempts.head.completed) should be (2) def makeAppInfo( id: String, @@ -132,11 +117,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) - list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, - oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, - -1L, oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, + list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -148,38 +129,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("Parse legacy logs with compression codec set") { - val provider = new FsHistoryProvider(createTestConf()) - val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), - (classOf[SnappyCompressionCodec].getName(), true), - ("invalid.codec", false)) - - testCodecs.foreach { case (codecName, valid) => - val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null - val logDir = new File(testDir, codecName) - logDir.mkdir() - createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", None, 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) - - val logPath = new Path(logDir.getAbsolutePath()) - try { - val logInput = provider.openLegacyEventLog(logPath) - try { - Source.fromInputStream(logInput).getLines().toSeq.size should be (2) - } finally { - logInput.close() - } - } catch { - case e: IllegalArgumentException => - valid should be (false) - } - } - } - test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, @@ -395,21 +344,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc SparkListenerLogStart("1.4") ) - // Write a 1.2 log file with no start event (= no app id), it should be ignored. - writeOldLog("v12Log", "1.2", None, false) - - // Write 1.0 and 1.1 logs, which don't have app ids. 
- writeOldLog("v11Log", "1.1", None, true, - SparkListenerApplicationStart("v11Log", None, 2L, "test", None), - SparkListenerApplicationEnd(3L)) - writeOldLog("v10Log", "1.0", None, true, - SparkListenerApplicationStart("v10Log", None, 2L, "test", None), - SparkListenerApplicationEnd(4L)) - updateAndCheck(provider) { list => - list.size should be (2) - list(0).id should be ("v10Log") - list(1).id should be ("v11Log") + list.size should be (0) } } @@ -499,25 +435,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } - private def writeOldLog( - fname: String, - sparkVersion: String, - codec: Option[CompressionCodec], - completed: Boolean, - events: SparkListenerEvent*): File = { - val log = new File(testDir, fname) - log.mkdir() - - val oldEventLog = new File(log, LOG_PREFIX + "1") - createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) - writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) - if (completed) { - createEmptyFile(new File(log, APPLICATION_COMPLETE)) - } - - log - } - private class SafeModeTestProvider(conf: SparkConf, clock: Clock) extends FsHistoryProvider(conf, clock) { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index be55b2e0fe1b7..40d0076eecfc8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -176,18 +176,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } } - test("download legacy logs - all attempts") { - doDownloadTest("local-1426533911241", None, legacy = true) - } - - test("download legacy logs - single attempts") { - (1 to 2). foreach { - attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) - } - } - // Test that the files are downloaded correctly, and validate them. - def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + def doDownloadTest(appId: String, attemptId: Option[Int]): Unit = { val url = attemptId match { case Some(id) => @@ -205,22 +195,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers var entry = zipStream.getNextEntry entry should not be null val totalFiles = { - if (legacy) { - attemptId.map { x => 3 }.getOrElse(6) - } else { - attemptId.map { x => 1 }.getOrElse(2) - } + attemptId.map { x => 1 }.getOrElse(2) } var filesCompared = 0 while (entry != null) { if (!entry.isDirectory) { val expectedFile = { - if (legacy) { - val splits = entry.getName.split("/") - new File(new File(logDir, splits(0)), splits(1)) - } else { - new File(logDir, entry.getName) - } + new File(logDir, entry.getName) } val expected = Files.toString(expectedFile, Charsets.UTF_8) val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) From 0fff5c6e6325357a241d311e72db942c4850af34 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:08:11 -0800 Subject: [PATCH 046/107] [SPARK-13130][SQL] Make codegen variable names easier to read 1. Use lower case 2. Change long prefixes to something shorter (in this case I am changing only one: TungstenAggregate -> agg). Author: Reynold Xin Closes #11017 from rxin/SPARK-13130. 
--- .../spark/sql/execution/WholeStageCodegen.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049f..02b0f423ed438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.util.Utils /** @@ -33,6 +34,12 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { + /** Prefix used in the current operator's variable names. */ + private def variablePrefix: String = this match { + case _: TungstenAggregate => "agg" + case _ => nodeName.toLowerCase + } + /** * Whether this SparkPlan support whole stage codegen or not. */ @@ -53,7 +60,7 @@ trait CodegenSupport extends SparkPlan { */ def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix doProduce(ctx) } @@ -94,7 +101,7 @@ trait CodegenSupport extends SparkPlan { child: SparkPlan, input: Seq[ExprCode], row: String = null): String = { - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix if (row != null) { ctx.currentVars = null ctx.INPUT_ROW = row From b8666fd0e2a797924eb2e94ac5558aba2a9b5140 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:37:06 -0800 Subject: [PATCH 047/107] Closes #10662. Closes #10661 From 22ba21348b28d8b1909ccde6fe17fb9e68531e5a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 16:48:59 +0800 Subject: [PATCH 048/107] [SPARK-13087][SQL] Fix group by function for sort based aggregation It is not valid to call `toAttribute` on a `NamedExpression` unless we know for sure that the child produced that `NamedExpression`. The current code worked fine when the grouping expressions were simple, but when they were a derived value this blew up at execution time. Author: Michael Armbrust Closes #11013 from marmbrus/groupByFunction-master. 
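A hedged sketch of the query shape this fixes, mirroring the regression test added to AggregationQuerySuite below; it assumes a spark-shell style sqlContext with Hive support, since collect_set is what sends this query down the sort-based aggregation path in the added test:

    import sqlContext.implicits._

    // Grouping by a derived expression (floor(a)) rather than a plain column previously
    // failed at execution time because toAttribute was called on an expression the child
    // plan never produced; after this patch the query returns a single row [1, [2]].
    Seq((1, 2)).toDF("a", "b").registerTempTable("data")
    sqlContext.sql(
      "SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a").show()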
--- .../org/apache/spark/sql/execution/aggregate/utils.scala | 5 ++--- .../spark/sql/hive/execution/AggregationQuerySuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 83379ae90f703..1e113ccd4e137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -33,15 +33,14 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + requiredChildDistributionExpressions = Some(groupingExpressions), + groupingExpressions = groupingExpressions, aggregateExpressions = completeAggregateExpressions, aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 3e4cf3f79e57c..7a9ed1eaf3dbc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -193,6 +193,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te sqlContext.dropTempTable("emptyTable") } + test("group by function") { + Seq((1, 2)).toDF("a", "b").registerTempTable("data") + + checkAnswer( + sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"), + Row(1, Array(2)) :: Nil) + } + test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( From 12a20c144f14e80ef120ddcfb0b455a805a2da23 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:13:54 -0800 Subject: [PATCH 049/107] [SPARK-10820][SQL] Support for the continuous execution of structured queries This is a follow up to 9aadcffabd226557174f3ff566927f873c71672e that extends Spark SQL to allow users to _repeatedly_ optimize and execute structured queries. A `ContinuousQuery` can be expressed using SQL, DataFrames or Datasets. The purpose of this PR is only to add some initial infrastructure which will be extended in subsequent PRs. ## User-facing API - `sqlContext.streamFrom` and `df.streamTo` return builder objects that are analogous to the `read/write` interfaces already available to executing queries in a batch-oriented fashion. - `ContinuousQuery` provides an interface for interacting with a query that is currently executing in the background. ## Internal Interfaces - `StreamExecution` - executes streaming queries in micro-batches The following are currently internal, but public APIs will be provided in a future release. - `Source` - an interface for providers of continually arriving data. A source must have a notion of an `Offset` that monotonically tracks what data has arrived. 
For fault tolerance, a source must be able to replay data given a start offset. - `Sink` - an interface that accepts the results of a continuously executing query. Also responsible for tracking the offset that should be resumed from in the case of a failure. ## Testing - `MemoryStream` and `MemorySink` - simple implementations of source and sink that keep all data in memory and have methods for simulating durability failures - `StreamTest` - a framework for performing actions and checking invariants on a continuous query Author: Michael Armbrust Author: Tathagata Das Author: Josh Rosen Closes #11006 from marmbrus/structured-streaming. --- .../apache/spark/sql/ContinuousQuery.scala | 30 ++ .../org/apache/spark/sql/DataFrame.scala | 8 + .../apache/spark/sql/DataStreamReader.scala | 127 +++++++ .../apache/spark/sql/DataStreamWriter.scala | 134 +++++++ .../org/apache/spark/sql/SQLContext.scala | 8 + .../datasources/ResolvedDataSource.scala | 33 +- .../spark/sql/execution/streaming/Batch.scala | 26 ++ .../execution/streaming/CompositeOffset.scala | 67 ++++ .../sql/execution/streaming/LongOffset.scala | 33 ++ .../sql/execution/streaming/Offset.scala | 37 ++ .../spark/sql/execution/streaming/Sink.scala | 47 +++ .../sql/execution/streaming/Source.scala | 36 ++ .../execution/streaming/StreamExecution.scala | 211 +++++++++++ .../execution/streaming/StreamProgress.scala | 67 ++++ .../streaming/StreamingRelation.scala | 34 ++ .../sql/execution/streaming/memory.scala | 138 +++++++ .../apache/spark/sql/sources/interfaces.scala | 21 ++ .../org/apache/spark/sql/QueryTest.scala | 74 ++-- .../org/apache/spark/sql/StreamTest.scala | 346 ++++++++++++++++++ .../sql/streaming/DataStreamReaderSuite.scala | 166 +++++++++ .../streaming/MemorySourceStressSuite.scala | 33 ++ .../spark/sql/streaming/OffsetSuite.scala | 98 +++++ .../spark/sql/streaming/StreamSuite.scala | 84 +++++ .../spark/sql/test/SharedSQLContext.scala | 2 +- 24 files changed, 1828 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala create mode 
100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala new file mode 100644 index 0000000000000..1c2c0290fc4cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * A handle to a query that is executing continuously in the background as new data arrives. + */ +trait ContinuousQuery { + + /** + * Stops the execution of this query if it is running. This method blocks until the threads + * performing execution has stopped. + */ + def stop(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 518f9dcf94a70..6de17e5924d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1690,6 +1690,14 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) + /** + * :: Experimental :: + * Interface for starting a streaming query that will continually output results to the specified + * external sink as new data arrives. + */ + @Experimental + def streamTo: DataStreamWriter = new DataStreamWriter(this) + /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala new file mode 100644 index 0000000000000..2febc93fa49d4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala @@ -0,0 +1,127 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. + * + * {{{ + * val df = sqlContext.streamFrom + * .format("...") + * .open() + * }}} + */ +@Experimental +class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { + + /** + * Specifies the input data source format. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data stream can + * skip the schema inference step, and thus speed up data reading. + * + * @since 2.0.0 + */ + def schema(schema: StructType): DataStreamReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data stream. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamReader = { + this.options(options.asScala) + this + } + + /** + * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. + * external key-value stores). + * + * @since 2.0.0 + */ + def open(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def open(path: String): DataFrame = { + option("path", path).open() + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala new file mode 100644 index 0000000000000..b325d48fcbbb1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamExecution + +/** + * :: Experimental :: + * Interface used to start a streaming query query execution. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamWriter private[sql](df: DataFrame) { + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamWriter = { + this.options(options.asScala) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme.\ + * @since 2.0.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataStreamWriter = { + this.partitioningColumns = colNames + this + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * @since 2.0.0 + */ + def start(path: String): ContinuousQuery = { + this.extraOptions += ("path" -> path) + start() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. 
+ * + * @since 2.0.0 + */ + def start(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + + private def normalizedParCols: Seq[String] = { + partitioningColumns.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Seq[String] = Nil + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ef993c3edae37..13700be06828d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -594,6 +594,14 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) + + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. + */ + @Experimental + def streamFrom: DataStreamReader = new DataStreamReader(this) + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index cc8dcf59307f2..e3065ac5f87d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -29,11 +29,11 @@ import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils - case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) @@ -92,6 +92,37 @@ object ResolvedDataSource extends Logging { } } + def createSource( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + providerName: String, + options: Map[String, String]): Source = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSourceProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed reading") + } + + provider.createSource(sqlContext, options, userSpecifiedSchema) + } + + def createSink( + sqlContext: SQLContext, + providerName: String, + options: Map[String, String], + partitionColumns: Seq[String]): Sink = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSinkProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed writing") + } + + provider.createSink(sqlContext, options, partitionColumns) + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala new file mode 100644 index 0000000000000..1f25eb8fc5223 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame + +/** + * Used to pass a batch of data through a streaming query execution along with an indication + * of progress in the stream. 
+ */ +class Batch(val end: Offset, val data: DataFrame) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala new file mode 100644 index 0000000000000..d2cb20ef8b819 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.Try + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + override def compareTo(other: Offset): Int = other match { + case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => + val comparisons = offsets.zip(otherComposite.offsets).map { + case (Some(a), Some(b)) => a compareTo b + case (None, None) => 0 + case (None, _) => -1 + case (_, None) => 1 + } + val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet + nonZeroSigns.size match { + case 0 => 0 // if both empty or only 0s + case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) + case _ => // there are both 1s and -1s + throw new IllegalArgumentException( + s"Invalid comparison between non-linear histories: $this <=> $other") + } + case _ => + throw new IllegalArgumentException(s"Cannot compare $this <=> $other") + } + + private def sign(num: Int): Int = num match { + case i if i < 0 => -1 + case i if i == 0 => 0 + case i if i > 0 => 1 + } +} + +object CompositeOffset { + /** + * Returns a [[CompositeOffset]] with a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(offsets: Offset*): CompositeOffset = { + CompositeOffset(offsets.map(Option(_))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala new file mode 100644 index 0000000000000..008195af38b75 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A simple offset for sources that produce a single linear stream of data. + */ +case class LongOffset(offset: Long) extends Offset { + + override def compareTo(other: Offset): Int = other match { + case l: LongOffset => offset.compareTo(l.offset) + case _ => + throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") + } + + def +(increment: Long): LongOffset = new LongOffset(offset + increment) + def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala new file mode 100644 index 0000000000000..0f5d6445b1e2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A offset is a monotonically increasing metric used to track progress in the computation of a + * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent + * with `equals` and `hashcode`. + */ +trait Offset extends Serializable { + + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. 
+ */ + def compareTo(other: Offset): Int + + def >(other: Offset): Boolean = compareTo(other) > 0 + def <(other: Offset): Boolean = compareTo(other) < 0 + def <=(other: Offset): Boolean = compareTo(other) <= 0 + def >=(other: Offset): Boolean = compareTo(other) >= 0 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala new file mode 100644 index 0000000000000..1bd71b6b02ea9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * An interface for systems that can collect the results of a streaming query. + * + * When new data is produced by a query, a [[Sink]] must be able to transactionally collect the + * data and update the [[Offset]]. In the case of a failure, the sink will be recreated + * and must be able to return the [[Offset]] for all of the data that is made durable. + * This contract allows Spark to process data with exactly-once semantics, even in the case + * of failures that require the computation to be restarted. + */ +trait Sink { + /** + * Returns the [[Offset]] for all data that is currently present in the sink, if any. This + * function will be called by Spark when restarting execution in order to determine at which point + * in the input stream computation should be resumed from. + */ + def currentOffset: Option[Offset] + + /** + * Accepts a new batch of data as well as a [[Offset]] that denotes how far in the input + * data computation has progressed to. When computation restarts after a failure, it is important + * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that + * has been persisted durrably. Note that this does not necessarily have to be the + * [[Offset]] for the most recent batch of data that was given to the sink. For example, + * it is valid to buffer data before persisting, as long as the [[Offset]] is stored + * transactionally as data is eventually persisted. + */ + def addBatch(batch: Batch): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala new file mode 100644 index 0000000000000..25922979ac83e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.types.StructType + +/** + * A source of continually arriving data for a streaming query. A [[Source]] must have a + * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark + * will regularly query each [[Source]] to see if any more data is available. + */ +trait Source { + + /** Returns the schema of the data from this source */ + def schema: StructType + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + def getNextBatch(start: Option[Offset]): Option[Batch] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala new file mode 100644 index 0000000000000..ebebb829710b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.Logging +import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. + * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any + * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created + * and the results are committed transactionally to the given [[Sink]]. + */ +class StreamExecution( + sqlContext: SQLContext, + private[sql] val logicalPlan: LogicalPlan, + val sink: Sink) extends ContinuousQuery with Logging { + + /** An monitor used to wait/notify when batches complete. 
*/ + private val awaitBatchLock = new Object + + @volatile + private var batchRun = false + + /** Minimum amount of time in between the start of each batch. */ + private val minBatchTime = 10 + + /** Tracks how much data we have processed from each input source. */ + private[sql] val streamProgress = new StreamProgress + + /** All stream sources present the query plan. */ + private val sources = + logicalPlan.collect { case s: StreamingRelation => s.source } + + // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data + // that we have already processed). + { + sink.currentOffset match { + case Some(c: CompositeOffset) => + val storedProgress = c.offsets + val sources = logicalPlan collect { + case StreamingRelation(source, _) => source + } + + assert(sources.size == storedProgress.size) + sources.zip(storedProgress).foreach { case (source, offset) => + offset.foreach(streamProgress.update(source, _)) + } + case None => // We are starting this stream for the first time. + case _ => throw new IllegalArgumentException("Expected composite offset from sink") + } + } + + logInfo(s"Stream running at $streamProgress") + + /** When false, signals to the microBatchThread that it should stop running. */ + @volatile private var shouldRun = true + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = new Thread("stream execution thread") { + override def run(): Unit = { + SQLContext.setActive(sqlContext) + while (shouldRun) { + attemptBatch() + Thread.sleep(minBatchTime) // TODO: Could be tighter + } + } + } + microBatchThread.setDaemon(true) + microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + } + }) + microBatchThread.start() + + @volatile + private[sql] var lastExecution: QueryExecution = null + @volatile + private[sql] var streamDeathCause: Throwable = null + + /** + * Checks to see if any new data is present in any of the sources. When new data is available, + * a batch is executed and passed to the sink, updating the currentOffsets. + */ + private def attemptBatch(): Unit = { + val startTime = System.nanoTime() + + // A list of offsets that need to be updated if this batch is successful. + // Populated while walking the tree. + val newOffsets = new ArrayBuffer[(Source, Offset)] + // A list of attributes that will need to be updated. + var replacements = new ArrayBuffer[(Attribute, Attribute)] + // Replace sources in the logical plan with data that has arrived since the last batch. + val withNewSources = logicalPlan transform { + case StreamingRelation(source, output) => + val prevOffset = streamProgress.get(source) + val newBatch = source.getNextBatch(prevOffset) + + newBatch.map { batch => + newOffsets += ((source, batch.end)) + val newPlan = batch.data.logicalPlan + + assert(output.size == newPlan.output.size) + replacements ++= output.zip(newPlan.output) + newPlan + }.getOrElse { + LocalRelation(output) + } + } + + // Rewire the plan to use the new attributes that were returned by the source. 
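+ // Each new batch returned by a source carries a fresh logical plan whose output attributes have + // new expression ids, so column references elsewhere in the query must be remapped below.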
+ val replacementMap = AttributeMap(replacements) + val newPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => replacementMap(a) + } + + if (newOffsets.nonEmpty) { + val optimizerStart = System.nanoTime() + + lastExecution = new QueryExecution(sqlContext, newPlan) + val executedPlan = lastExecution.executedPlan + val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 + logDebug(s"Optimized batch in ${optimizerTime}ms") + + streamProgress.synchronized { + // Update the offsets and calculate a new composite offset + newOffsets.foreach(streamProgress.update) + val newStreamProgress = logicalPlan.collect { + case StreamingRelation(source, _) => streamProgress.get(source) + } + val batchOffset = CompositeOffset(newStreamProgress) + + // Construct the batch and send it to the sink. + val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) + sink.addBatch(nextBatch) + } + + batchRun = true + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + + val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 + logInfo(s"Completed up to $newOffsets in ${batchTime}ms") + } + + logDebug(s"Waiting for data, current: $streamProgress") + } + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + def stop(): Unit = { + shouldRun = false + if (microBatchThread.isAlive) { microBatchThread.join() } + } + + /** + * Blocks the current thread until processing for data from the given `source` has reached at + * least the given `Offset`. This method is intended for use primarily when writing tests. + */ + def awaitOffset(source: Source, newOffset: Offset): Unit = { + def notDone = streamProgress.synchronized { + !streamProgress.contains(source) || streamProgress(source) < newOffset + } + + while (notDone) { + logInfo(s"Waiting until $newOffset at $source") + awaitBatchLock.synchronized { awaitBatchLock.wait(100) } + } + logDebug(s"Unblocked at $newOffset for $source") + } + + override def toString: String = + s""" + |=== Streaming Query === + |CurrentOffsets: $streamProgress + |Thread State: ${microBatchThread.getState} + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |$logicalPlan + """.stripMargin +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala new file mode 100644 index 0000000000000..0ded1d7152c19 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +/** + * A helper class that looks like a Map[Source, Offset]. + */ +class StreamProgress { + private val currentOffsets = new mutable.HashMap[Source, Offset] + + private[streaming] def update(source: Source, newOffset: Offset): Unit = { + currentOffsets.get(source).foreach(old => + assert(newOffset > old, s"Stream going backwards $newOffset -> $old")) + currentOffsets.put(source, newOffset) + } + + private[streaming] def update(newOffset: (Source, Offset)): Unit = + update(newOffset._1, newOffset._2) + + private[streaming] def apply(source: Source): Offset = currentOffsets(source) + private[streaming] def get(source: Source): Option[Offset] = currentOffsets.get(source) + private[streaming] def contains(source: Source): Boolean = currentOffsets.contains(source) + + private[streaming] def ++(updates: Map[Source, Offset]): StreamProgress = { + val updated = new StreamProgress + currentOffsets.foreach(updated.update) + updates.foreach(updated.update) + updated + } + + /** + * Used to create a new copy of this [[StreamProgress]]. While this class is currently mutable, + * it should be copied before being passed to user code. + */ + private[streaming] def copy(): StreamProgress = { + val copied = new StreamProgress + currentOffsets.foreach(copied.update) + copied + } + + override def toString: String = + currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") + + override def equals(other: Any): Boolean = other match { + case s: StreamProgress => currentOffsets == s.currentOffsets + case _ => false + } + + override def hashCode: Int = currentOffsets.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala new file mode 100644 index 0000000000000..e35c444348f48 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode + +object StreamingRelation { + def apply(source: Source): StreamingRelation = + StreamingRelation(source, source.schema.toAttributes) +} + +/** + * Used to link a streaming [[Source]] of data into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. 
+ */ +case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode { + override def toString: String = source.toString +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala new file mode 100644 index 0000000000000..e6a0842936ea2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder} +import org.apache.spark.sql.types.StructType + +object MemoryStream { + protected val currentBlockId = new AtomicInteger(0) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] + * is primarily intended for use in unit tests as it can only replay data when the object is still + * available. 
+ */ +case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends Source with Logging { + protected val encoder = encoderFor[A] + protected val logicalPlan = StreamingRelation(this) + protected val output = logicalPlan.output + protected val batches = new ArrayBuffer[Dataset[A]] + protected var currentOffset: LongOffset = new LongOffset(-1) + + protected def blockManager = SparkEnv.get.blockManager + + def schema: StructType = encoder.schema + + def getCurrentOffset: Offset = currentOffset + + def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { + new Dataset(sqlContext, logicalPlan) + } + + def toDF()(implicit sqlContext: SQLContext): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } + + def addData(data: TraversableOnce[A]): Offset = { + import sqlContext.implicits._ + this.synchronized { + currentOffset = currentOffset + 1 + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") + batches.append(ds) + currentOffset + } + } + + override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized { + val newBlocks = + batches.drop( + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1) + + if (newBlocks.nonEmpty) { + logDebug(s"Running [$start, $currentOffset] on blocks ${newBlocks.mkString(", ")}") + val df = newBlocks + .map(_.toDF()) + .reduceOption(_ unionAll _) + .getOrElse(sqlContext.emptyDataFrame) + + Some(new Batch(currentOffset, df)) + } else { + None + } + } + + override def toString: String = s"MemoryStream[${output.mkString(",")}]" +} + +/** + * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit + * tests and does not provide durability. + */ +class MemorySink(schema: StructType) extends Sink with Logging { + /** An ordered list of batches that have been written to this [[Sink]]. */ + private var batches = new ArrayBuffer[Batch]() + + /** Used to convert an [[InternalRow]] to an external [[Row]] for comparison in testing. */ + private val externalRowConverter = RowEncoder(schema) + + override def currentOffset: Option[Offset] = synchronized { + batches.lastOption.map(_.end) + } + + override def addBatch(nextBatch: Batch): Unit = synchronized { + batches.append(nextBatch) + } + + /** Returns all rows that are stored in this [[Sink]]. */ + def allData: Seq[Row] = synchronized { + batches + .map(_.data) + .reduceOption(_ unionAll _) + .map(_.collect().toSeq) + .getOrElse(Seq.empty) + } + + /** + * Atomically drops the most recent `num` batches and resets the [[StreamProgress]] to the + * corresponding point in the input. This function can be used when testing to simulate data + * that has been lost due to buffering.
+ */ + def dropBatches(num: Int): Unit = synchronized { + batches.dropRight(num) + } + + override def toString: String = synchronized { + batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n") + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8911ad370aa7b..299fc6efbb046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -123,6 +124,26 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +/** + * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. + */ +trait StreamSourceProvider { + def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source +} + +/** + * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. + */ +trait StreamSinkProvider { + def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink +} + /** * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ce12f788b786c..405e5891ac976 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -304,27 +304,7 @@ object QueryTest { def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. 
- val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } val sparkAnswer = try df.collect().toSeq catch { case e: Exception => val errorMessage = @@ -338,22 +318,56 @@ object QueryTest { return Some(errorMessage) } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: |${df.queryExecution} |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + |$results + """.stripMargin } + } + + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } - return None + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val errorMessage = + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + None } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala new file mode 100644 index 0000000000000..f45abbf2496a2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.streaming._ + +/** + * A framework for implementing tests for streaming queries and sources. + * + * A test consists of a set of steps (expressed as a `StreamAction`) that are executed in order, + * blocking as necessary to let the stream catch up. For example, the following adds some data to + * a stream, blocking until it can verify that the correct values are eventually produced. + * + * {{{ + * val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + CheckAnswer(2, 3, 4)) + * }}} + * + * Note that while we do sleep to allow the other thread to progress without spinning, + * `StreamAction` checks should not depend on the amount of time spent sleeping. Instead they + * should check the actual progress of the stream before verifying the required test condition. + * + * Currently it is assumed that all streaming queries will eventually complete in 10 seconds to + * avoid hanging forever in the case of failures. However, individual suites can change this + * by overriding `streamingTimeout`. + */ +trait StreamTest extends QueryTest with Timeouts { + + implicit class RichSource(s: Source) { + def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s)) + } + + /** How long to wait for an active stream to catch up when checking a result. */ + val streamingTimout = 10.seconds + + /** A trait for actions that can be performed while testing a streaming DataFrame. */ + trait StreamAction + + /** A trait to mark actions that require the stream to be actively running. */ + trait StreamMustBeRunning + + /** + * Adds the given data to the stream. Subsequent check answers will block until this data has + * been processed. + */ + object AddData { + def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + AddDataMemory(source, data) + } + + /** A trait that can be extended when testing other sources. */ + trait AddData extends StreamAction { + def source: Source + + /** + * Called to trigger adding the data. Should return the offset that will denote when this + * new data has been processed. + */ + def addData(): Offset + } + + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + override def toString: String = s"AddData to $source: ${data.mkString(",")}" + + override def addData(): Offset = { + source.addData(data) + } + } + + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. + * This operation automatically blocks until all added data has been processed.
+ */ + object CheckAnswer { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + } + + case class DropBatches(num: Int) extends StreamAction + + /** Stops the stream. It must currently be running. */ + case object StopStream extends StreamAction with StreamMustBeRunning + + /** Starts the stream, resuming if data has already been processed. It must not be running. */ + case object StartStream extends StreamAction + + /** Signals that a failure is expected and should not kill the test. */ + case object ExpectFailure extends StreamAction + + /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */ + def testStream(stream: Dataset[_])(actions: StreamAction*): Unit = + testStream(stream.toDF())(actions: _*) + + /** + * Executes the specified actions on the given streaming DataFrame and provides helpful + * error messages in the case of failures or incorrect answers. + * + * Note that if the stream is not explicitly started before an action that requires it to be + * running then it will be automatically started before performing any other actions. + */ + def testStream(stream: DataFrame)(actions: StreamAction*): Unit = { + var pos = 0 + var currentPlan: LogicalPlan = stream.logicalPlan + var currentStream: StreamExecution = null + val awaiting = new mutable.HashMap[Source, Offset]() + val sink = new MemorySink(stream.schema) + + @volatile + var streamDeathCause: Throwable = null + + // If the test doesn't manually start the stream, we do it automatically at the beginning.
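+ // A StartStream that appears before the first action requiring a running stream counts as + // a manual start; otherwise a StartStream action is prepended to the action sequence.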
+ val startedManually = + actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream) + val startedTest = if (startedManually) actions else StartStream +: actions + + def testActions = actions.zipWithIndex.map { + case (a, i) => + if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) { + "=> " + a.toString + } else { + " " + a.toString + } + }.mkString("\n") + + def currentOffsets = + if (currentStream != null) currentStream.streamProgress.toString else "not started" + + def threadState = + if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def testState = + s""" + |== Progress == + |$testActions + | + |== Stream == + |Stream state: $currentOffsets + |Thread state: $threadState + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |== Sink == + |$sink + | + |== Plan == + |${if (currentStream != null) currentStream.lastExecution else ""} + """ + + def checkState(check: Boolean, error: String) = if (!check) { + fail( + s""" + |Invalid State: $error + |$testState + """.stripMargin) + } + + val testThread = Thread.currentThread() + + try { + startedTest.foreach { action => + action match { + case StartStream => + checkState(currentStream == null, "stream already running") + + currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink) + currentStream.microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + testThread.interrupt() + } + }) + + case StopStream => + checkState(currentStream != null, "can not stop a stream that is not running") + currentStream.stop() + currentStream = null + + case DropBatches(num) => + checkState(currentStream == null, "dropping batches while running leads to corruption") + sink.dropBatches(num) + + case ExpectFailure => + try failAfter(streamingTimout) { + while (streamDeathCause == null) { + Thread.sleep(100) + } + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out while waiting for failure. + |$testState + """.stripMargin) + } + + currentStream = null + streamDeathCause = null + + case a: AddData => + awaiting.put(a.source, a.addData()) + + case CheckAnswerRows(expectedAnswer) => + checkState(currentStream != null, "stream not running") + + // Block until all data added has been processed + awaiting.foreach { case (source, offset) => + failAfter(streamingTimout) { + currentStream.awaitOffset(source, offset) + } + } + + val allData = try sink.allData catch { + case e: Exception => + fail( + s""" + |Exception while getting data from sink $e + |$testState + """.stripMargin) + } + + QueryTest.sameRows(expectedAnswer, allData).foreach { + error => fail( + s""" + |$error + |$testState + """.stripMargin) + } + } + pos += 1 + } + } catch { + case _: InterruptedException if streamDeathCause != null => + fail( + s""" + |Stream Thread Died + |$testState + """.stripMargin) + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out waiting for stream + |$testState + """.stripMargin) + } finally { + if (currentStream != null && currentStream.microBatchThread.isAlive) { + currentStream.stop() + } + } + } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. 
+ * @param addData and add data action that adds the given numbers to the stream, encoding them + * as needed + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + implicit val intEncoder = ExpressionEncoder[Int]() + var dataPos = 0 + var running = true + val actions = new ArrayBuffer[StreamAction]() + + def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } + + def addRandomData() = { + val numItems = Random.nextInt(10) + val data = dataPos until (dataPos + numItems) + dataPos += numItems + actions += addData(data) + } + + (1 to iterations).foreach { i => + val rand = Random.nextDouble() + if(!running) { + rand match { + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StartStream + running = true + } + } else { + rand match { + case r if r < 0.1 => + addCheck() + + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StopStream + running = false + } + } + } + if(!running) { actions += StartStream } + addCheck() + testStream(ds)(actions: _*) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala new file mode 100644 index 0000000000000..1dab6ebf1bee9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.test + +import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest} +import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source} +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +object LastOptions { + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var partitionColumns: Seq[String] = Nil +} + +/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. 
*/ +class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + override def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source = { + LastOptions.parameters = parameters + LastOptions.schema = schema + new Source { + override def getNextBatch(start: Option[Offset]): Option[Batch] = None + override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + } + } + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink = { + LastOptions.parameters = parameters + LastOptions.partitionColumns = partitionColumns + new Sink { + override def addBatch(batch: Batch): Unit = {} + override def currentOffset: Option[Offset] = None + } + } +} + +class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("resolve default source") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("resolve full class") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test.DefaultSource") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .open() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .start() + .stop() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("partitioning") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + assert(LastOptions.partitionColumns == Nil) + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("a") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + + + withSQLConf("spark.sql.caseSensitive" -> "false") { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("A") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + } + + intercept[AnalysisException] { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("b") + .start() + .stop() + } + } + + test("stream paths") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start("/test") + .stop() + + assert(LastOptions.parameters("path") == "/test") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala new file mode 100644 index 0000000000000..81760d2aa8205 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class MemorySourceStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("memory stress test") { + val input = MemoryStream[Int] + val mapped = input.toDS().map(_ + 1) + + runStressTest(mapped, AddData(input, _: _*)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala new file mode 100644 index 0000000000000..989465826d54e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, Offset} + +trait OffsetSuite extends SparkFunSuite { + /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ + def compare(one: Offset, two: Offset): Unit = { + test(s"comparision $one <=> $two") { + assert(one < two) + assert(one <= two) + assert(one <= one) + assert(two > one) + assert(two >= one) + assert(one >= one) + assert(one == one) + assert(two == two) + assert(one != two) + assert(two != one) + } + } + + /** Creates test to check that non-equality comparisons throw exception. 
*/ + def compareInvalid(one: Offset, two: Offset): Unit = { + test(s"invalid comparison $one <=> $two") { + intercept[IllegalArgumentException] { + assert(one < two) + } + + intercept[IllegalArgumentException] { + assert(one <= two) + } + + intercept[IllegalArgumentException] { + assert(one > two) + } + + intercept[IllegalArgumentException] { + assert(one >= two) + } + + assert(!(one == two)) + assert(!(two == one)) + assert(one != two) + assert(two != one) + } + } +} + +class LongOffsetSuite extends OffsetSuite { + val one = LongOffset(1) + val two = LongOffset(2) + compare(one, two) +} + +class CompositeOffsetSuite extends OffsetSuite { + compare( + one = CompositeOffset(Some(LongOffset(1)) :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset(None :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compareInvalid( // sizes must be same + one = CompositeOffset(Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compare( + one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compareInvalid( + one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala new file mode 100644 index 0000000000000..fbb1792596b18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class StreamSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("map with recovery") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + StartStream, + CheckAnswer(2, 3, 4), + StopStream, + AddData(inputData, 4, 5, 6), + StartStream, + CheckAnswer(2, 3, 4, 5, 6, 7)) + } + + test("join") { + // Make a table and ensure it will be broadcast. + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. 
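+ // Each micro-batch from the stream is joined against the small static table defined above.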
+ val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number") + + testStream(joined)( + AddData(inputData, 1, 2, 3), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")), + AddData(inputData, 4), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) + } + + test("union two streams") { + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val unioned = inputData1.toDS().union(inputData2.toDS()) + + testStream(unioned)( + AddData(inputData1, 1, 3, 5), + CheckAnswer(1, 3, 5), + AddData(inputData2, 2, 4, 6), + CheckAnswer(1, 2, 3, 4, 5, 6), + StopStream, + AddData(inputData1, 7), + StartStream, + AddData(inputData2, 8), + CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8)) + } + + test("sql queries") { + val inputData = MemoryStream[Int] + inputData.toDF().registerTempTable("stream") + val evens = sql("SELECT * FROM stream WHERE value % 2 = 0") + + testStream(evens)( + AddData(inputData, 1, 2, 3, 4), + CheckAnswer(2, 4)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787c..c341191c70bb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -36,7 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. From 29d92181d0c49988c387d34e4a71b1afe02c29e2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:15:40 -0800 Subject: [PATCH 050/107] [SPARK-13094][SQL] Add encoders for seq/array of primitives Author: Michael Armbrust Closes #11014 from marmbrus/seqEncoders. 
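For illustration only (not part of the patch), a minimal sketch of how the new implicit encoders might be exercised, assuming an active SQLContext named `sqlContext`; the value names here are hypothetical:

    import sqlContext.implicits._

    // Seq and Array of primitives (and of Products) now have implicit encoders in scope,
    // so toDS() resolves without supplying a custom Encoder.
    val seqDS = Seq(Seq(1, 2, 3)).toDS()       // Dataset[Seq[Int]]
    val arrDS = Seq(Array("a", "b")).toDS()    // Dataset[Array[String]]
    seqDS.collect()                            // a single row holding Seq(1, 2, 3)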
--- .../org/apache/spark/sql/SQLImplicits.scala | 63 ++++++++++++++++++- .../spark/sql/DatasetPrimitiveSuite.scala | 22 +++++++ .../org/apache/spark/sql/QueryTest.scala | 8 ++- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ab414799f1a42..16c4095db722a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -39,6 +39,8 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + // Primitives + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -56,13 +58,72 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() - /** @since 1.6.0 */ + /** @since 1.6.0 */ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + // Seqs + + /** @since 1.6.1 */ + implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + + // Arrays + + /** @since 1.6.1 */ + implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = + ExpressionEncoder() + /** * Creates a [[Dataset]] from an RDD. 
* @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index f75d0961823c4..243d13b19d6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { agged, "1", "abc", "3", "xyz", "5", "hello") } + + test("Arrays and Lists") { + checkAnswer(Seq(Seq(1)).toDS(), Seq(1)) + checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkAnswer(Seq(Seq(true)).toDS(), Seq(true)) + checkAnswer(Seq(Seq("test")).toDS(), Seq("test")) + checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkAnswer(Seq(Array(1)).toDS(), Array(1)) + checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkAnswer(Seq(Array(true)).toDS(), Array(true)) + checkAnswer(Seq(Array("test")).toDS(), Array("test")) + checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 405e5891ac976..5401212428d6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - if (decoded != expectedAnswer.toSet) { + // Handle the case where the return type is an array + val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) + def normalEquality = decoded == expectedAnswer.toSet + def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet + def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) + + if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted From b93830126cc59a26e2cfb5d7b3c17f9cfbf85988 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 2 Feb 2016 10:41:06 -0800 Subject: [PATCH 051/107] [SPARK-13114][SQL] Add a test for tokens more than the fields in schema https://issues.apache.org/jira/browse/SPARK-13114 This PR adds a test for tokens more than the fields in schema. Author: hyukjinkwon Closes #11020 from HyukjinKwon/SPARK-13114. --- sql/core/src/test/resources/cars-malformed.csv | 6 ++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 sql/core/src/test/resources/cars-malformed.csv diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/cars-malformed.csv new file mode 100644 index 0000000000000..cfa378c01f1d9 --- /dev/null +++ b/sql/core/src/test/resources/cars-malformed.csv @@ -0,0 +1,6 @@ +~ All the rows here are malformed having tokens more than the schema (header). 
+year,make,model,comment,blank +"2012","Tesla","S","No comment",,null,null + +1997,Ford,E350,"Go get one now they are going fast",,null,null +2015,Chevy,,,, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index a79566b1f3658..fa4f137b703b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.types._ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val carsFile = "cars.csv" + private val carsMalformedFile = "cars-malformed.csv" private val carsFile8859 = "cars_iso-8859-1.csv" private val carsTsvFile = "cars.tsv" private val carsAltFile = "cars-alternative.csv" @@ -191,6 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) } + test("test for tokens more than the fields in the schema") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .option("comment", "~") + .load(testFile(carsMalformedFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + test("test with null quote character") { val cars = sqlContext.read .format("csv") From cba1d6b659288bfcd8db83a6d778155bab2bbecf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 Feb 2016 10:50:22 -0800 Subject: [PATCH 052/107] [SPARK-12631][PYSPARK][DOC] PySpark clustering parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the clustering module. Author: Bryan Cutler Closes #10610 from BryanCutler/param-desc-consistent-cluster-SPARK-12631. --- .../mllib/clustering/GaussianMixture.scala | 12 +- .../spark/mllib/clustering/KMeans.scala | 31 +- .../apache/spark/mllib/clustering/LDA.scala | 13 +- .../clustering/PowerIterationClustering.scala | 4 +- .../mllib/clustering/StreamingKMeans.scala | 6 +- python/pyspark/mllib/clustering.py | 265 +++++++++++++----- 6 files changed, 228 insertions(+), 103 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 7b203e2f40815..88dbfe3fcc9f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils * This is due to high-dimensional data (a) making it difficult to cluster at all (based * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. * - * @param k The number of independent Gaussians in the mixture model - * @param convergenceTol The maximum change in log-likelihood at which convergence - * is considered to have occurred. - * @param maxIterations The maximum number of iterations to perform + * @param k Number of independent Gaussians in the mixture model. + * @param convergenceTol Maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations Maximum number of iterations allowed. 
*/ @Since("1.3.0") class GaussianMixture private ( @@ -108,7 +108,7 @@ class GaussianMixture private ( def getK: Int = k /** - * Set the maximum number of iterations to run. Default: 100 + * Set the maximum number of iterations allowed. Default: 100 */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -117,7 +117,7 @@ class GaussianMixture private ( } /** - * Return the maximum number of iterations to run + * Return the maximum number of iterations allowed */ @Since("1.3.0") def getMaxIterations: Int = maxIterations diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ca11ede4ccd47..901164a391170 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -70,13 +70,13 @@ class KMeans private ( } /** - * Maximum number of iterations to run. + * Maximum number of iterations allowed. */ @Since("1.4.0") def getMaxIterations: Int = maxIterations /** - * Set maximum number of iterations to run. Default: 20. + * Set maximum number of iterations allowed. Default: 20. */ @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -482,12 +482,15 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). - * @param seed random seed value for cluster initialization + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. */ @Since("1.3.0") def train( @@ -508,11 +511,13 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". 
(default: "k-means||") */ @Since("0.8.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index eb802a365ed6e..81566b4779d66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -61,14 +61,13 @@ class LDA private ( ldaOptimizer = new EMLDAOptimizer) /** - * Number of topics to infer. I.e., the number of soft cluster centers. - * + * Number of topics to infer, i.e., the number of soft cluster centers. */ @Since("1.3.0") def getK: Int = k /** - * Number of topics to infer. I.e., the number of soft cluster centers. + * Set the number of topics to infer, i.e., the number of soft cluster centers. * (default = 10) */ @Since("1.3.0") @@ -222,13 +221,13 @@ class LDA private ( def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** - * Maximum number of iterations for learning. + * Maximum number of iterations allowed. */ @Since("1.3.0") def getMaxIterations: Int = maxIterations /** - * Maximum number of iterations for learning. + * Set the maximum number of iterations allowed. * (default = 20) */ @Since("1.3.0") @@ -238,13 +237,13 @@ class LDA private ( } /** - * Random seed + * Random seed for cluster initialization. */ @Since("1.3.0") def getSeed: Long = seed /** - * Random seed + * Set the random seed for cluster initialization. */ @Since("1.3.0") def setSeed(seed: Long): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 2ab0920b06363..1ab7cb393b081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -111,7 +111,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. - * @param initMode Initialization mode. + * @param initMode Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. + * Default: random. * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 79d217e183c62..d99b89dc49ebf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -183,7 +183,7 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the decay factor directly (for forgetful algorithms). + * Set the forgetfulness of the previous centroids. */ @Since("1.2.0") def setDecayFactor(a: Double): this.type = { @@ -192,7 +192,9 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + * Set the half life and time unit ("batches" or "points"). If points, then the decay factor + * is raised to the power of number of new points and if batches, then decay factor will be + * used as is. 
*/ @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 4e9eb96fd9da1..ad04e46e8870b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -88,8 +88,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -105,7 +108,8 @@ def computeCost(self, x): points to their nearest center) for this model on the given data. If provided with an RDD of points returns the sum. - :param point: the point or RDD of points to compute the cost(s). + :param point: + A data point (or RDD of points) to compute the cost(s). """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -143,17 +147,23 @@ def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1 """ Runs the bisecting k-means algorithm return the model. - :param rdd: input RDD to be trained on - :param k: The desired number of leaf clusters (default: 4). - The actual number could be smaller if there are no divisible - leaf clusters. - :param maxIterations: the max number of k-means iterations to - split clusters (default: 20) - :param minDivisibleClusterSize: the minimum number of points - (if >= 1.0) or the minimum proportion of points (if < 1.0) - of a divisible cluster (default: 1) - :param seed: a random seed (default: -1888008604 from - classOf[BisectingKMeans].getName.##) + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + The desired number of leaf clusters. The actual number could + be smaller if there are no divisible leaf clusters. + (default: 4) + :param maxIterations: + Maximum number of iterations allowed to split clusters. + (default: 20) + :param minDivisibleClusterSize: + Minimum number of points (if >= 1.0) or the minimum proportion + of points (if < 1.0) of a divisible cluster. + (default: 1) + :param seed: + Random seed value for cluster initialization. + (default: -1888008604 from classOf[BisectingKMeans].getName.##) """ java_model = callMLlibFunc( "trainBisectingKMeans", rdd.map(_convert_to_vector), @@ -239,8 +249,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ best = 0 best_distance = float("inf") @@ -262,7 +275,8 @@ def computeCost(self, rdd): their nearest center) for this model on the given data. - :param point: the RDD of points to compute the cost on. + :param rdd: + The RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) @@ -296,7 +310,44 @@ class KMeans(object): @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): - """Train a k-means clustering model.""" + """ + Train a k-means clustering model. 
+ + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of clusters to create. + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param runs: + Number of runs to execute in parallel. The best model according + to the cost function will be returned (deprecated in 1.6.0). + (default: 1) + :param initializationMode: + The initialization algorithm. This can be either "random" or + "k-means||". + (default: "k-means||") + :param seed: + Random seed value for cluster initialization. Set as None to + generate seed based on system time. + (default: None) + :param initializationSteps: + Number of steps for the k-means|| initialization mode. + This is an advanced setting -- the default of 5 is almost + always enough. + (default: 5) + :param epsilon: + Distance threshold within which a center will be considered to + have converged. If all centers move less than this Euclidean + distance, iterations are stopped. + (default: 1e-4) + :param initialModel: + Initial cluster centers can be provided as a KMeansModel object + rather than using the random or k-means|| initializationModel. + (default: None) + """ if runs != 1: warnings.warn( "Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.") @@ -415,8 +466,11 @@ def predict(self, x): Find the cluster to which the point 'x' or each point in RDD 'x' has maximum membership in this model. - :param x: vector or RDD of vector represents data points. - :return: cluster label or RDD of cluster labels. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + Predicted cluster label or an RDD of predicted cluster labels + if the input is an RDD. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) @@ -430,9 +484,11 @@ def predictSoft(self, x): """ Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: vector or RDD of vector represents data points. - :return: the membership value to all mixture components for vector 'x' - or each vector in RDD 'x'. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + The membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -447,8 +503,10 @@ def predictSoft(self, x): def load(cls, sc, path): """Load the GaussianMixtureModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ model = cls._load_java(sc, path) wrapper = sc._jvm.GaussianMixtureModelWrapper(model) @@ -461,19 +519,35 @@ class GaussianMixture(object): Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. - :param data: RDD of data points - :param k: Number of components - :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 - :param maxIterations: Number of iterations. Default to 100 - :param seed: Random Seed - :param initialModel: GaussianMixtureModel for initializing learning - .. versionadded:: 1.3.0 """ @classmethod @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): - """Train a Gaussian Mixture clustering model.""" + """ + Train a Gaussian Mixture clustering model. 
+ + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of independent Gaussians in the mixture model. + :param convergenceTol: + Maximum change in log-likelihood at which convergence is + considered to have occurred. + (default: 1e-3) + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param seed: + Random seed for initial Gaussian distribution. Set as None to + generate seed based on system time. + (default: None) + :param initialModel: + Initial GMM starting point, bypassing the random + initialization. + (default: None) + """ initialModelWeights = None initialModelMu = None initialModelSigma = None @@ -574,18 +648,24 @@ class PowerIterationClustering(object): @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ - :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the - affinity matrix, which is the matrix A in the PIC paper. - The similarity s,,ij,, must be nonnegative. - This is a symmetric matrix and hence s,,ij,, = s,,ji,,. - For any (i, j) with nonzero similarity, there should be - either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. - Tuples with i = j are ignored, because we assume - s,,ij,, = 0.0. - :param k: Number of clusters. - :param maxIterations: Maximum number of iterations of the - PIC algorithm. - :param initMode: Initialization mode. + :param rdd: + An RDD of (i, j, s\ :sub:`ij`\) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. The + similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric + matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with + nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or + (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, + because it is assumed s\ :sub:`ij`\ = 0.0. + :param k: + Number of clusters. + :param maxIterations: + Maximum number of iterations of the PIC algorithm. + (default: 100) + :param initMode: + Initialization mode. This can be either "random" to use + a random vector as vertex properties, or "degree" to use + normalized sum similarities. + (default: "random") """ model = callMLlibFunc("trainPowerIterationClusteringModel", rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) @@ -625,8 +705,10 @@ class StreamingKMeansModel(KMeansModel): and new data. If it set to zero, the old centroids are completely forgotten. - :param clusterCenters: Initial cluster centers. - :param clusterWeights: List of weights assigned to each cluster. + :param clusterCenters: + Initial cluster centers. + :param clusterWeights: + List of weights assigned to each cluster. >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] >>> initWeights = [1.0, 1.0] @@ -673,11 +755,14 @@ def clusterWeights(self): def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data - :param data: Should be a RDD that represents the new data. - :param decayFactor: forgetfulness of the previous centroids. - :param timeUnit: Can be "batches" or "points". If points, then the - decay factor is raised to the power of number of new - points and if batches, it is used as it is. + :param data: + RDD with new data for the model update. + :param decayFactor: + Forgetfulness of the previous centroids. + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor + is raised to the power of number of new points and if batches, + then decay factor will be used as is. 
""" if not isinstance(data, RDD): raise TypeError("Data should be of an RDD, got %s." % type(data)) @@ -704,10 +789,17 @@ class StreamingKMeans(object): More details on how the centroids are updated are provided under the docs of StreamingKMeansModel. - :param k: int, number of clusters - :param decayFactor: float, forgetfulness of the previous centroids. - :param timeUnit: can be "batches" or "points". If points, then the - decayfactor is raised to the power of no. of new points. + :param k: + Number of clusters. + (default: 2) + :param decayFactor: + Forgetfulness of the previous centroids. + (default: 1.0) + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor is + raised to the power of number of new points and if batches, then + decay factor will be used as is. + (default: "batches") .. versionadded:: 1.5.0 """ @@ -870,11 +962,13 @@ def describeTopics(self, maxTermsPerTopic=None): WARNING: If vocabSize and k are large, this can return a large object! - :param maxTermsPerTopic: Maximum number of terms to collect for each topic. - (default: vocabulary size) - :return: Array over topics. Each topic is represented as a pair of matching arrays: - (term indices, term weights in topic). - Each topic's terms are sorted in order of decreasing weight. + :param maxTermsPerTopic: + Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: + Array over topics. Each topic is represented as a pair of + matching arrays: (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") @@ -887,8 +981,10 @@ def describeTopics(self, maxTermsPerTopic=None): def load(cls, sc, path): """Load the LDAModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) @@ -909,17 +1005,38 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. - :param rdd: RDD of data points - :param k: Number of clusters you want - :param maxIterations: Number of iterations. Default to 20 - :param docConcentration: Concentration parameter (commonly named "alpha") - for the prior placed on documents' distributions over topics ("theta"). - :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") - for the prior placed on topics' distributions over terms. - :param seed: Random Seed - :param checkpointInterval: Period (in iterations) between checkpoints. - :param optimizer: LDAOptimizer used to perform the actual calculation. - Currently "em", "online" are supported. Default to "em". + :param rdd: + RDD of documents, which are tuples of document IDs and term + (word) count vectors. The term count vectors are "bags of + words" with a fixed-size vocabulary (where the vocabulary size + is the length of the vector). Document IDs must be unique + and >= 0. + :param k: + Number of topics to infer, i.e., the number of soft cluster + centers. + (default: 10) + :param maxIterations: + Maximum number of iterations allowed. + (default: 20) + :param docConcentration: + Concentration parameter (commonly named "alpha") for the prior + placed on documents' distributions over topics ("theta"). 
+ (default: -1.0) + :param topicConcentration: + Concentration parameter (commonly named "beta" or "eta") for + the prior placed on topics' distributions over terms. + (default: -1.0) + :param seed: + Random seed for cluster initialization. Set as None to generate + seed based on system time. + (default: None) + :param checkpointInterval: + Period (in iterations) between checkpoints. + (default: 10) + :param optimizer: + LDAOptimizer used to perform the actual calculation. Currently + "em", "online" are supported. + (default: "em") """ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, docConcentration, topicConcentration, seed, From 358300c795025735c3b2f96c5447b1b227d4abc1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 2 Feb 2016 11:09:40 -0800 Subject: [PATCH 053/107] [SPARK-13056][SQL] map column would throw NPE if value is null Jira: https://issues.apache.org/jira/browse/SPARK-13056 Create a map like { "a": "somestring", "b": null} Query like SELECT col["b"] FROM t1; NPE would be thrown. Author: Daoyuan Wang Closes #10964 from adrian-wang/npewriter. --- .../expressions/complexTypeExtractors.scala | 15 +++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5256baaf432a2..9f2f82d68cca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { + if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) @@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression) val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() + val values = map.valueArray() var i = 0 var found = false @@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!found) { + if (!found || values.isNullAt(i)) { null } else { - map.valueArray().get(i, dataType) + values.get(i, dataType) } } @@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression) val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") + val values = ctx.freshName("values") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); + final ArrayData $values = $eval1.valueArray(); int $index = 0; boolean $found = false; @@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if ($found) { - ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; - } else { + if (!$found || $values.isNullAt($index)) { ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(values, dataType, index)}; } """ }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 
2b821c1056f56..79bfd4b44b70a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-13056: Null in map value causes NPE") { + val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") + withTempTable("maptest") { + df.registerTempTable("maptest") + // local optimization will by pass codegen code, so we should keep the filter `key=1` + checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) + checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) + } + } + test("hash function") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") withTempTable("tbl") { From b1835d727234fdff42aa8cadd17ddcf43b0bed15 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Tue, 2 Feb 2016 11:16:24 -0800 Subject: [PATCH 054/107] [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication Fixes problem and verifies fix by test suite. Also - adds optional parameter: nullable (Boolean) to: SchemaUtils.appendColumn and deduplicates SchemaUtils.appendColumn functions. Author: Grzegorz Chilkiewicz Closes #10741 from grzegorz-chilkiewicz/master. --- .../spark/ml/feature/StopWordsRemover.scala | 4 +--- .../org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- .../spark/ml/feature/StopWordsRemoverSuite.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index b93c9ed382bdf..e53ef300f644b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index e71dd9eee03e3..76021ad8f4e65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -71,12 +71,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index fb217e0c1de93..a5b24c18565b9 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
       .setCaseSensitive(true)
     testDefaultReadWrite(t)
   }
+
+  test("StopWordsRemover output column already exists") {
+    val outputCol = "expected"
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol(outputCol)
+    val dataSet = sqlContext.createDataFrame(Seq(
+      (Seq("The", "the", "swift"), Seq("swift"))
+    )).toDF("raw", outputCol)
+
+    val thrown = intercept[IllegalArgumentException] {
+      testStopWordsRemover(remover, dataSet)
+    }
+    assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+  }
 }

From 7f6e3ec79b77400f558ceffa10b2af011962115f Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 2 Feb 2016 11:29:20 -0800
Subject: [PATCH 055/107] [SPARK-13138][SQL] Add "logical" package prefix for ddl.scala

ddl.scala is defined in the execution package, and yet the "UnaryNode" and "Command" it references are the logical ones. This was fairly confusing when I was trying to understand the ddl code.

Author: Reynold Xin

Closes #11021 from rxin/SPARK-13138.
---
 .../spark/sql/execution/datasources/ddl.scala | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 1554209be9891..a141b58d3d72c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.types._
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types._
  */
 case class DescribeCommand(
     table: LogicalPlan,
-    isExtended: Boolean) extends LogicalPlan with Command {
+    isExtended: Boolean) extends LogicalPlan with logical.Command {
   override def children: Seq[LogicalPlan] = Seq.empty
@@ -59,7 +60,7 @@ case class CreateTableUsing(
     temporary: Boolean,
     options: Map[String, String],
     allowExisting: Boolean,
-    managedIfNoPath: Boolean) extends LogicalPlan with Command {
+    managedIfNoPath: Boolean) extends LogicalPlan with logical.Command {
   override def output: Seq[Attribute] = Seq.empty
   override def children: Seq[LogicalPlan] = Seq.empty
 }
 /**
  * A node used to support CTAS statements and saveAsTable for the data source API.
- * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer
- * can analyze the logical plan that will be used to populate the table.
+ * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the
+ * analyzer can analyze the logical plan that will be used to populate the table.
  * So, [[PreWriteCheck]] can detect cases that are not allowed.
*/ case class CreateTableUsingAsSelect( @@ -79,7 +80,7 @@ case class CreateTableUsingAsSelect( bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan) extends logical.UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } From be5dd881f1eff248224a92d57cfd1309cb3acf38 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 11:50:14 -0800 Subject: [PATCH 056/107] [SPARK-12913] [SQL] Improve performance of stat functions As benchmarked and discussed here: https://github.com/apache/spark/pull/10786/files#r50038294, benefits from codegen, the declarative aggregate function could be much faster than imperative one. Author: Davies Liu Closes #10960 from davies/stddev. --- .../catalyst/analysis/HiveTypeCoercion.scala | 18 +- .../aggregate/CentralMomentAgg.scala | 285 ++++++++---------- .../catalyst/expressions/aggregate/Corr.scala | 208 ++++--------- .../expressions/aggregate/Covariance.scala | 205 ++++--------- .../expressions/aggregate/Kurtosis.scala | 54 ---- .../expressions/aggregate/Skewness.scala | 53 ---- .../expressions/aggregate/Stddev.scala | 81 ----- .../expressions/aggregate/Variance.scala | 81 ----- .../spark/sql/catalyst/expressions/misc.scala | 18 ++ .../apache/spark/sql/execution/Window.scala | 6 +- .../aggregate/TungstenAggregate.scala | 1 - .../BenchmarkWholeStageCodegen.scala | 55 +++- .../execution/HiveCompatibilitySuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 17 +- 14 files changed, 331 insertions(+), 755 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 957ac89fa530d..57bdb164e1a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -347,18 +347,12 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - 
Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) + case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) + case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 30f602227b17d..9d2db45144817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -44,7 +42,7 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. */ -abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { +abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { /** * The central moment order to be computed. @@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w protected def momentOrder: Int override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = true - override def dataType: DataType = DoubleType + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val avg = AttributeReference("avg", DoubleType, nullable = false)() + protected val m2 = AttributeReference("m2", DoubleType, nullable = false)() + protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() + protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") + private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) - /** - * Size of aggregation buffer. 
- */ - private[this] val bufferSize = 5 + override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => - AttributeReference(s"M$i", DoubleType)() + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) } - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // buffer offsets - private[this] val nOffset = mutableAggBufferOffset - private[this] val meanOffset = mutableAggBufferOffset + 1 - private[this] val secondMomentOffset = mutableAggBufferOffset + 2 - private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 - private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 - - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - private[this] var n = 0.0 - private[this] var mean = 0.0 - private[this] var m2 = 0.0 - private[this] var m3 = 0.0 - private[this] var m4 = 0.0 + override val mergeExpressions: Seq[Expression] = { - /** - * Initialize all moments to zero. - */ - override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until bufferSize) { - buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val delta = avg.right - avg.left + val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val newAvg = avg.left + deltaN * n2 + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2 + // `m3.right` is not available if momentOrder < 3 + val newM3 = if (momentOrder >= 3) { + m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + + Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left) + } else { + Literal(0.0) } + // `m4.right` is not available if momentOrder < 4 + val newM4 = if (momentOrder >= 4) { + m4.left + m4.right + + deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + + Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) + + Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } +} - /** - * Update the central moments buffer. 
- */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = Cast(child, DoubleType).eval(input) - if (v != null) { - val updateValue = v match { - case d: Double => d - } - - n = buffer.getDouble(nOffset) - mean = buffer.getDouble(meanOffset) - - n += 1.0 - buffer.setDouble(nOffset, n) - delta = updateValue - mean - deltaN = delta / n - mean += deltaN - buffer.setDouble(meanOffset, mean) - - if (momentOrder >= 2) { - m2 = buffer.getDouble(secondMomentOffset) - m2 += delta * (delta - deltaN) - buffer.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - delta2 = delta * delta - deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(thirdMomentOffset) - m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - m4 = buffer.getDouble(fourthMomentOffset) - m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + - delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(fourthMomentOffset, m4) - } - } +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + Sqrt(m2 / n)) } - /** - * Merge two central moment buffers. - */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(nOffset) - val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(meanOffset) - val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 - var secondMoment1 = 0.0 - var secondMoment2 = 0.0 + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + Sqrt(m2 / (n - Literal(1.0))))) + } - var thirdMoment1 = 0.0 - var thirdMoment2 = 0.0 + override def prettyName: String = "stddev_samp" +} - var fourthMoment1 = 0.0 - var fourthMoment2 = 0.0 +// Compute the population variance of a column +case class VariancePop(child: Expression) extends CentralMomentAgg(child) { - n = n1 + n2 - buffer1.setDouble(nOffset, n) - delta = mean2 - mean1 - deltaN = if (n == 0.0) 0.0 else delta / n - mean = mean1 + deltaN * n2 - buffer1.setDouble(mutableAggBufferOffset + 1, mean) + override protected def momentOrder = 2 - // higher order moments computed according to: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics - if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(secondMomentOffset) - secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(secondMomentOffset, m2) - } + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + m2 / n) + } - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(thirdMomentOffset) - thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * - (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(thirdMomentOffset, m3) - } + override def prettyName: String = "var_pop" +} - if (momentOrder >= 4) { 
- fourthMoment1 = buffer1.getDouble(fourthMomentOffset) - fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * - n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * - (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + - 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(fourthMomentOffset, m4) - } +// Compute the sample variance of a column +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + m2 / (n - Literal(1.0)))) } - /** - * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) - * needed to compute the aggregate stat. - */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any - - override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(nOffset) - val mean = buffer.getDouble(meanOffset) - val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = 1.0 - moments(1) = 0.0 - if (momentOrder >= 2) { - moments(2) = buffer.getDouble(secondMomentOffset) - } - if (momentOrder >= 3) { - moments(3) = buffer.getDouble(thirdMomentOffset) - } - if (momentOrder >= 4) { - moments(4) = buffer.getDouble(fourthMomentOffset) - } + override def prettyName: String = "var_samp" +} + +case class Skewness(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "skewness" + + override protected def momentOrder = 3 - getStatistic(n, mean, moments) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) } } + +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 4 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + n * m4 / (m2 * m2) - Literal(3.0))) + } + + override def prettyName: String = "kurtosis" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index d25f3335ffd93..e6b8214ef25e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -29,165 +28,70 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -case class Corr( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - def this(left: Expression, right: Expression) = - this(left, right, 
mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def children: Seq[Expression] = Seq(left, right) +case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"corr requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)() + protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) + + override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) + + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dxN = dx / newN + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dxN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + val newXMk = xMk + dx * (x - newXAvg) + val newYMk = yMk + dy * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk), + If(isNull, xMk, newXMk), + If(isNull, yMk, newYMk) + ) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) + override val mergeExpressions: Seq[Expression] = { + + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) } - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("MkX", DoubleType)(), - AttributeReference("MkY", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 - private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 - private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 - private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 - private[this] val mutableAggBufferOffsetPlus5 = 
mutableAggBufferOffset + 5 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 - private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 - private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 - private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 - private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(mutableAggBufferOffset, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) - buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / Sqrt(xMk * yMk))) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. 
- if (count2 > 0) { - var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer1.getLong(mutableAggBufferOffsetPlus5) - - val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) - val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) - val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) - val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) - - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 - MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 - count = totalCount - - buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer1.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffsetPlus5) - if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - val corr = Ck / math.sqrt(MkX * MkY) - if (corr.isNaN) { - null - } else { - corr - } - } else { - null - } - } + override def prettyName: String = "corr" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index f53b01be2a0d5..c175a8c4c77b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -17,182 +17,79 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Compute the covariance between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. 
- * */ -abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate - with Serializable { - override def children: Seq[Expression] = Seq(left, right) +abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"covariance requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } - } - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) - } - - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - val xAvgOffset = mutableAggBufferOffset - val yAvgOffset = mutableAggBufferOffset + 1 - val CkOffset = mutableAggBufferOffset + 2 - val countOffset = mutableAggBufferOffset + 3 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - val inputXAvgOffset = inputAggBufferOffset - val inputYAvgOffset = inputAggBufferOffset + 1 - val inputCkOffset = inputAggBufferOffset + 2 - val inputCountOffset = inputAggBufferOffset + 3 - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(xAvgOffset, 0.0) - buffer.setDouble(yAvgOffset, 0.0) - buffer.setDouble(CkOffset, 0.0) - buffer.setLong(countOffset, 0L) - } - - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(xAvgOffset) - var yAvg = buffer.getDouble(yAvgOffset) - var Ck = buffer.getDouble(CkOffset) - var count = buffer.getLong(countOffset) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - - buffer.setDouble(xAvgOffset, xAvg) - buffer.setDouble(yAvgOffset, yAvg) - buffer.setDouble(CkOffset, Ck) - buffer.setLong(countOffset, count) - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) + + override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) + + override lazy val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + 
If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) } - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputCountOffset) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. - if (count2 > 0) { - var xAvg = buffer1.getDouble(xAvgOffset) - var yAvg = buffer1.getDouble(yAvgOffset) - var Ck = buffer1.getDouble(CkOffset) - var count = buffer1.getLong(countOffset) + override val mergeExpressions: Seq[Expression] = { - val xAvg2 = buffer2.getDouble(inputXAvgOffset) - val yAvg2 = buffer2.getDouble(inputYAvgOffset) - val Ck2 = buffer2.getDouble(inputCkOffset) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - count = totalCount - - buffer1.setDouble(xAvgOffset, xAvg) - buffer1.setDouble(yAvgOffset, yAvg) - buffer1.setDouble(CkOffset, Ck) - buffer1.setLong(countOffset, count) - } + Seq(newN, newXAvg, newYAvg, newCk) } } -case class CovSample( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 1) { - val Ck = buffer.getDouble(CkOffset) - val cov = Ck / (count - 1) - if (cov.isNaN) { - null - } else { - cov - } - } else { - null - } +case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + ck / n) } + override def prettyName: String = "covar_pop" } -case class CovPopulation( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 0) { - val Ck = buffer.getDouble(CkOffset) - if (Ck.isNaN) { - null - } else { - Ck / count - 
} - } else { - null - } +case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / (n - Literal(1.0)))) } + override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala deleted file mode 100644 index c2bf2cb94116c..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Kurtosis(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "kurtosis" - - override protected val momentOrder = 4 - - // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m4 = moments(4) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - n * m4 / (m2 * m2) - 3.0 - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala deleted file mode 100644 index 9411bcea2539a..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Skewness(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "skewness" - - override protected val momentOrder = 3 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m3 = moments(3) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala deleted file mode 100644 index eec79a9033e36..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class StddevSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - math.sqrt(moments(2) / (n - 1.0)) - } - } -} - -case class StddevPop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - math.sqrt(moments(2) / n) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala deleted file mode 100644 index cf3a740305391..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class VarianceSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - moments(2) / (n - 1.0) - } - } -} - -case class VariancePop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - moments(2) / n - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 36e1fa1176d22..f4ccadd9c563e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -424,3 +424,21 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } } + +/** + * Print the result of an expression to stderr (used for debugging codegen). 
+ */ +case class PrintToStderr(child: Expression) extends UnaryExpression { + + override def dataType: DataType = child.dataType + + protected override def nullSafeEval(input: Any): Any = input + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, c => + s""" + | System.err.println("Result of ${child.simpleString} is " + $c); + | ${ev.value} = $c; + """.stripMargin) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 26a7340f1ae10..84154a47de393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -198,7 +198,8 @@ case class Window( functions, ordinal, child.output, - (expressions, schema) => newMutableProjection(expressions, schema)) + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) // Create the factory val factory = key match { @@ -210,7 +211,8 @@ case class Window( ordinal, functions, child.output, - (expressions, schema) => newMutableProjection(expressions, schema), + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 57db7262fdaf3..a8a81d6d6574e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -240,7 +240,6 @@ case class TungstenAggregate( | ${bufVars(i).value} = ${ev.value}; """.stripMargin } - s""" | // do aggregate | ${aggVals.map(_.code).mkString("\n")} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2f09c8a114bc5..1ccf0e3d0656c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -59,6 +59,55 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testStatFunctions(values: Int): Unit = { + + val benchmark = new Benchmark("stat functions", values) + + benchmark.addCase("stddev w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("stddev w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("kurtosis w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + benchmark.addCase("kurtosis w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + + /** + Using ImperativeAggregate (as implemented in Spark 1.6): + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- 
+ stddev w/o codegen 2019.04 10.39 1.00 X + stddev w codegen 2097.29 10.00 0.96 X + kurtosis w/o codegen 2108.99 9.94 0.96 X + kurtosis w codegen 2090.69 10.03 0.97 X + + Using DeclarativeAggregate: + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 989.22 21.20 1.00 X + stddev w codegen 352.35 59.52 2.81 X + kurtosis w/o codegen 3636.91 5.77 0.27 X + kurtosis w codegen 369.25 56.79 2.68 X + */ + benchmark.run() + } + def testAggregateWithKey(values: Int): Unit = { val benchmark = new Benchmark("Aggregate with keys", values) @@ -147,8 +196,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } - test("benchmark") { - // testWholeStage(1024 * 1024 * 200) + // These benchmark are skipped in normal build + ignore("benchmark") { + // testWholeStage(200 << 20) + // testStddev(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 554d47d651aef..61b73fa557144 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_ignore_protection", "protectmode", + // Hive returns null rather than NaN when n = 1 + "udaf_covar_samp", + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. 
"udf_abs", "udf_format_number", @@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_widening", "udaf_collect_set", "udaf_covar_pop", - "udaf_covar_samp", "udaf_histogram_numeric", "udf2", "udf5", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7a9ed1eaf3dbc..caf1db9ad0855 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -798,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), - Row(null) :: Nil) + Row(Double.NaN) :: Nil) checkAnswer( sqlContext.sql( @@ -807,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, null) :: Row(2, null) :: - Row(3, null) :: - Row(4, null) :: - Row(5, null) :: - Row(6, null) :: Nil) + Row(3, Double.NaN) :: + Row(4, Double.NaN) :: + Row(5, Double.NaN) :: + Row(6, Double.NaN) :: Nil) val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) @@ -841,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // one row test val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) - assert(cov_samp3 == null) - - val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) - assert(cov_pop3 == 0.0) + checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN)) + checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } test("no aggregation function (SPARK-11486)") { From d0df2ca40953ba581dce199798a168af01283cdc Mon Sep 17 00:00:00 2001 From: Gabriele Nizzoli Date: Tue, 2 Feb 2016 13:20:01 -0800 Subject: [PATCH 057/107] [SPARK-13121][STREAMING] java mapWithState mishandles scala Option Already merged into 1.6 branch, this PR is to commit to master the same change Author: Gabriele Nizzoli Closes #11028 from gabrielenizzoli/patch-1. 
--- .../src/main/scala/org/apache/spark/streaming/StateSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 66f646d7dc136..e6724feaee105 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -221,7 +221,7 @@ object StateSpec { mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { - mappingFunction.call(k, Optional.ofNullable(v.get), s) + mappingFunction.call(k, JavaUtils.optionToOptional(v), s) } StateSpec.function(wrappedFunc) } From b377b03531d21b1d02a8f58b3791348962e1f31b Mon Sep 17 00:00:00 2001 From: "Kevin (Sangwoo) Kim" Date: Tue, 2 Feb 2016 13:24:09 -0800 Subject: [PATCH 058/107] [DOCS] Update StructType.scala The example will throw error like :20: error: not found: value StructType Need to add this line: import org.apache.spark.sql.types._ Author: Kevin (Sangwoo) Kim Closes #10141 from swkimme/patch-1. --- .../src/main/scala/org/apache/spark/sql/types/StructType.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9e7e7fe633b8..e797d83cb05be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParse * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * * val struct = * StructType( From 6de6a97728408ee2619006decf2267cc43eeea0d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 16:24:31 -0800 Subject: [PATCH 059/107] [SPARK-13150] [SQL] disable two flaky tests Author: Davies Liu Closes #11037 from davies/disable_flaky. 
--- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ba3b26e1b7d49..9860e40fe8546 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,7 +488,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - test("SPARK-11595 ADD JAR with input path having URL scheme") { + // TODO: enable this + ignore("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -546,7 +547,8 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - test("test single session") { + // TODO: enable this + ignore("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 672032d0ab1e43bc5a25cecdb1b96dfd35c39778 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Feb 2016 08:26:35 +0800 Subject: [PATCH 060/107] [SPARK-13020][SQL][TEST] fix random generator for map type when we generate map, we first randomly pick a length, then create a seq of key value pair with the expected length, and finally call `toMap`. However, `toMap` will remove all duplicated keys, which makes the actual map size much less than we expected. This PR fixes this problem by put keys in a set first, to guarantee we have enough keys to build a map with expected length. Author: Wenchen Fan Closes #10930 from cloud-fan/random-generator. --- .../apache/spark/sql/RandomDataGenerator.scala | 18 ++++++++++++++---- .../spark/sql/RandomDataGeneratorSuite.scala | 11 +++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 55efea80d1a4d..7c173cbceefed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -47,9 +47,9 @@ object RandomDataGenerator { */ private val PROBABILITY_OF_NULL: Float = 0.1f - private val MAX_STR_LEN: Int = 1024 - private val MAX_ARR_SIZE: Int = 128 - private val MAX_MAP_SIZE: Int = 128 + final val MAX_STR_LEN: Int = 1024 + final val MAX_ARR_SIZE: Int = 128 + final val MAX_MAP_SIZE: Int = 128 /** * Helper function for constructing a biased random number generator which returns "interesting" @@ -208,7 +208,17 @@ object RandomDataGenerator { forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { - Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + val length = rand.nextInt(MAX_MAP_SIZE) + val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*) + // In case the number of different keys is not enough, set a max iteration to avoid + // infinite loop. 
+ var count = 0 + while (keys.size < length && count < MAX_MAP_SIZE) { + keys += keyGenerator() + count += 1 + } + val values = Seq.fill(keys.size)(valueGenerator()) + keys.zip(values).toMap } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index b8ccdf7516d82..9fba7924e9542 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -95,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite { } } + test("check size of generated map") { + val mapType = MapType(IntegerType, IntegerType) + for (seed <- 1 to 1000) { + val generator = RandomDataGenerator.forType( + mapType, nullable = false, rand = new Random(seed)).get + val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]]) + val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE + val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements) + assert(deviation.toDouble / expectedTotalElements < 2e-1) + } + } } From 21112e8a14c042ccef4312079672108a1082a95e Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 2 Feb 2016 16:33:21 -0800 Subject: [PATCH 061/107] [SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch. This patch implements support for more types when doing the vectorized decode. There are a few more types remaining but they should be very straightforward after this. This code has a few copy and paste pieces but they are difficult to eliminate due to performance considerations. Specifically, this patch adds support for: - String, Long, Byte types - Dictionary encoding for those types. Author: Nong Li Closes #10908 from nongli/spark-12992. 
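
For context on the dictionary path: a dictionary-encoded page stores each distinct value once and rows carry only small integer ids, so the vectorized reader first decodes a whole batch of ids and then resolves them against the dictionary. A toy sketch of that two-step lookup (plain arrays standing in for the ColumnVector buffers and Parquet dictionary used in this patch):

    // Hypothetical illustration only; the values and ids are made up.
    val dictionary    = Array("spark", "parquet", "columnar")   // distinct values, stored once per page
    val dictionaryIds = Array(0, 2, 2, 1, 0)                    // one compact id per row

    // Step two: resolve the batch of ids into the decoded column.
    val decoded: Array[String] = dictionaryIds.map(dictionary(_))
    // decoded == Array("spark", "columnar", "columnar", "parquet", "spark")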
--- .../parquet/UnsafeRowParquetRecordReader.java | 146 ++++++++++++++-- .../parquet/VectorizedPlainValuesReader.java | 45 ++++- .../parquet/VectorizedRleValuesReader.java | 160 +++++++++++++++++- .../parquet/VectorizedValuesReader.java | 5 + .../execution/vectorized/ColumnVector.java | 7 +- .../parquet/ParquetEncodingSuite.scala | 82 +++++++++ 6 files changed, 424 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 17adfec32192f..b5dddb9f11b22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.List; +import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException { int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { - switch (columnReaders[i].descriptor.getType()) { - case INT32: - columnReaders[i].readIntBatch(num, columnarBatch.column(i)); - break; - default: - throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType()); - } + columnReaders[i].readBatch(num, columnarBatch.column(i)); } rowsReturned += num; columnarBatch.setNumRows(num); @@ -237,7 +233,8 @@ private void initializeInternal() throws IOException { // TODO: Be extremely cautious in what is supported. Expand this. if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && - originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE && + originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && @@ -464,6 +461,11 @@ private final class ColumnReader { */ private boolean useDictionary; + /** + * If useDictionary is true, the staging vector used to decode the ids. + */ + private ColumnVector dictionaryIds; + /** * Maximum definition level for this column. */ @@ -587,9 +589,8 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. - * TODO: implement the other encodings. */ - private void readIntBatch(int total, ColumnVector column) throws IOException { + private void readBatch(int total, ColumnVector column) throws IOException { int rowId = 0; while (total > 0) { // Compute the number of values we want to read in this page. 
@@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException { leftInPage = (int)(endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0); - - // Remap the values if it is dictionary encoded. if (useDictionary) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(column.getInt(i))); + // Data is dictionary encoded. We will vector decode the ids and then resolve the values. + if (dictionaryIds == null) { + dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(total); + } + // Read and decode dictionary ids. + readIntBatch(rowId, num, dictionaryIds); + decodeDictionaryIds(rowId, num, column); + } else { + switch (descriptor.getType()) { + case INT32: + readIntBatch(rowId, num, column); + break; + case INT64: + readLongBatch(rowId, num, column); + break; + case BINARY: + readBinaryBatch(rowId, num, column); + break; + default: + throw new IOException("Unsupported type: " + descriptor.getType()); } } + valuesRead += num; rowId += num; total -= num; } } + /** + * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. + */ + private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + switch (descriptor.getType()) { + case INT32: + if (column.dataType() == DataTypes.IntegerType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case INT64: + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + break; + + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; + + default: + throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + } + + if (dictionaryIds.numNulls() > 0) { + // Copy the NULLs over. + // TODO: we can improve this by decoding the NULLs directly into column. This would + // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then + // just do the ID remapping as above. + for (int i = 0; i < num; ++i) { + if (dictionaryIds.getIsNull(rowId + i)) { + column.putNull(rowId + i); + } + } + } + } + + /** + * For all the read*Batch functions, reads `num` values from this columnReader into column. It + * is guaranteed that num is smaller than the number of values left in the current page. + */ + + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. 
+ // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.IntegerType) { + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + } else if (column.dataType() == DataTypes.ByteType) { + defColumn.readBytes( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.LongType) { + defColumn.readLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.isArray()) { + defColumn.readBinarys( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readPage() throws IOException { DataPage page = pageReader.readPage(); // TODO: Why is this a visitor? diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index dac0c52ebd2cf..cec2418e46030 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -18,10 +18,13 @@ import java.io.IOException; +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.unsafe.Platform; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. @@ -52,15 +55,53 @@ public void skip(int n) { } @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, ColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public int readInteger() { + public final void readLongs(int total, ColumnVector c, int rowId) { + c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readBytes(int total, ColumnVector c, int rowId) { + for (int i = 0; i < total; i++) { + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
+ c.putInt(rowId + i, buffer[offset]); + offset += 4; + } + } + + @Override + public final int readInteger() { int v = Platform.getInt(buffer, offset); offset += 4; return v; } + + @Override + public final long readLong() { + long v = Platform.getLong(buffer, offset); + offset += 8; + return v; + } + + @Override + public final byte readByte() { + return (byte)readInteger(); + } + + @Override + public final void readBinary(int total, ColumnVector v, int rowId) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + int start = offset; + offset += len; + v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 493ec9deed499..9bfd74db38766 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.datasources.parquet; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; import org.apache.parquet.column.values.bitpacking.Packer; import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -35,7 +39,8 @@ * - Definition/Repetition levels * - Dictionary ids. */ -public final class VectorizedRleValuesReader extends ValuesReader { +public final class VectorizedRleValuesReader extends ValuesReader + implements VectorizedValuesReader { // Current decoding mode. The encoded data contains groups of either run length encoded data // (RLE) or bit packed data. Each group contains a header that indicates which group it is and // the number of values in the group. @@ -121,6 +126,7 @@ public int readValueDictionaryId() { return readInteger(); } + @Override public int readInteger() { if (this.currentCount == 0) { this.readNextGroup(); } @@ -138,7 +144,9 @@ public int readInteger() { /** * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. + * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only + * necessary for readIntegers because we also use it to decode dictionaryIds and want to make + * sure it always has a value in range. * * This is a batched version of this logic: * if (this.readInt() == level) { @@ -180,6 +188,154 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } } + // TODO: can this code duplication be removed without a perf penalty? 
+ public void readBytes(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBytes(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putByte(rowId + i, data.readByte()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readLongs(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readLongs(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readLong()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBinarys(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + c.putNotNulls(rowId, n); + data.readBinary(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putNotNull(rowId + i); + data.readBinary(1, c, rowId); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + + // The RLE reader implements the vectorized decoding interface when used to decode dictionary + // IDs. This is different than the above APIs that decodes definitions levels along with values. + // Since this is only used to decode dictionary IDs, only decoding integers is supported. + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putInts(rowId, n, currentValue); + break; + case PACKED: + c.putInts(rowId, n, currentBuffer, currentBufferIdx); + currentBufferIdx += n; + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + @Override + public byte readByte() { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBytes(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readLongs(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBinary(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void skip(int n) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + /** * Reads the next varint encoded int. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 49a9ed83d590a..b6ec7311c564a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -24,12 +24,17 @@ * TODO: merge this into parquet-mr. */ public interface VectorizedValuesReader { + byte readByte(); int readInteger(); + long readLong(); /* * Reads `total` values into `c` start at `c[rowId]` */ + void readBytes(int total, ColumnVector c, int rowId); void readIntegers(int total, ColumnVector c, int rowId); + void readLongs(int total, ColumnVector c, int rowId); + void readBinary(int total, ColumnVector c, int rowId); // TODO: add all the other parquet types. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a5bc506a65ac2..0514252a8e53d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -763,7 +763,12 @@ public final int appendStruct(boolean isNull) { /** * Returns the elements appended. */ - public int getElementsAppended() { return elementsAppended; } + public final int getElementsAppended() { return elementsAppended; } + + /** + * Returns true if this column is an array. + */ + public final boolean isArray() { return resultArray != null; } /** * Maximum number of rows that can be stored in this column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala new file mode 100644 index 0000000000000..cef6b79a094d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils +import org.apache.spark.sql.test.SharedSQLContext + +// TODO: this needs a lot more testing but it's currently not easy to test with the parquet +// writer abstractions. Revisit. 
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext { + import testImplicits._ + + val ROW = ((1).toByte, 2, 3L, "abc") + val NULL_ROW = ( + null.asInstanceOf[java.lang.Byte], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[String]) + + test("All Types Dictionary") { + (1 :: 1000 :: Nil).foreach { n => { + withTempPath { dir => + List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getByte(i) == 1) + assert(batch.column(1).getInt(i) == 2) + assert(batch.column(2).getLong(i) == 3) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + i += 1 + } + reader.close() + } + }} + } + + test("All Types Null") { + (1 :: 100 :: Nil).foreach { n => { + withTempPath { dir => + val data = List.fill(n)(NULL_ROW).toDF + data.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getIsNull(i)) + assert(batch.column(1).getIsNull(i)) + assert(batch.column(2).getIsNull(i)) + assert(batch.column(3).getIsNull(i)) + i += 1 + } + reader.close() + }} + } + } +} From ff71261b651a7b289ea2312abd6075da8b838ed9 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Tue, 2 Feb 2016 19:35:33 -0800 Subject: [PATCH 062/107] [SPARK-13122] Fix race condition in MemoryStore.unrollSafely() https://issues.apache.org/jira/browse/SPARK-13122 A race condition can occur in MemoryStore's unrollSafely() method if two threads that return the same value for currentTaskAttemptId() execute this method concurrently. This change makes the operation of reading the initial amount of unroll memory used, performing the unroll, and updating the associated memory maps atomic in order to avoid this race condition. Initial proposed fix wraps all of unrollSafely() in a memoryManager.synchronized { } block. A cleaner approach might be introduce a mechanism that synchronizes based on task attempt ID. An alternative option might be to track unroll/pending unroll memory based on block ID rather than task attempt ID. Author: Adam Budde Closes #11012 from budde/master. 
---
 .../org/apache/spark/storage/MemoryStore.scala     | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 76aaa782b9524..024b660ce6a7b 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -255,8 +255,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
     var memoryThreshold = initialMemoryThreshold
     // Memory to request as a multiple of current vector size
     val memoryGrowthFactor = 1.5
-    // Previous unroll memory held by this task, for releasing later (only at the very end)
-    val previousMemoryReserved = currentUnrollMemoryForThisTask
+    // Keep track of pending unroll memory reserved by this method.
+    var pendingMemoryReserved = 0L
     // Underlying vector for unrolling the block
     var vector = new SizeTrackingVector[Any]
@@ -266,6 +266,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
     if (!keepUnrolling) {
       logWarning(s"Failed to reserve initial memory threshold of " +
         s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+    } else {
+      pendingMemoryReserved += initialMemoryThreshold
     }

     // Unroll this block safely, checking whether we have exceeded our threshold periodically
@@ -278,6 +280,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
         if (currentSize >= memoryThreshold) {
           val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
           keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+          if (keepUnrolling) {
+            pendingMemoryReserved += amountToRequest
+          }
           // New threshold is currentSize * memoryGrowthFactor
           memoryThreshold += amountToRequest
         }
@@ -304,10 +309,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
           // release the unroll memory yet. Instead, we transfer it to pending unroll memory
           // so `tryToPut` can further transfer it to normal storage memory later.
           // TODO: we can probably express this without pending unroll memory (SPARK-10907)
-          val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved
-          unrollMemoryMap(taskAttemptId) -= amountToTransferToPending
+          unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved
           pendingUnrollMemoryMap(taskAttemptId) =
-            pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending
+            pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved
         }
       } else {
         // Otherwise, if we return an iterator, we can only release the unroll memory when

From 99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 2 Feb 2016 19:47:44 -0800
Subject: [PATCH 063/107] [SPARK-12951] [SQL] support spilling in generated
 aggregate

This PR adds spilling support for the generated TungstenAggregate. If spilling happens,
falling back to the iterator-based sort-merge aggregate (not generated) is not that bad.

The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite.

Author: Davies Liu

Closes #10998 from davies/gen_spilling.
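Conceptually, the fallback produces the same result as the hash-based path: once the
spilled rows are sorted by grouping key, runs of equal keys are folded into a single
aggregation buffer. A small self-contained Scala sketch of that merge loop (the generic
key and buffer types here are placeholders; the generated code below works on UnsafeRow
buffers coming out of UnsafeKVExternalSorter):

    // Sort-merge aggregation over an iterator that is already sorted by key.
    // `merge` combines two aggregation buffers that belong to the same key.
    def sortMergeAggregate[K, B](sorted: Iterator[(K, B)], merge: (B, B) => B): Iterator[(K, B)] =
      new Iterator[(K, B)] {
        private val it = sorted.buffered
        override def hasNext: Boolean = it.hasNext
        override def next(): (K, B) = {
          val (key, first) = it.next()
          var buffer = first
          // consume the rest of this key's run and merge it into one buffer
          while (it.hasNext && it.head._1 == key) {
            buffer = merge(buffer, it.next()._2)
          }
          key -> buffer
        }
      }

    // e.g. sortMergeAggregate(Iterator("a" -> 1, "a" -> 2, "b" -> 5), (x: Int, y: Int) => x + y)
    // yields ("a", 3) and ("b", 5)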
--- .../aggregate/TungstenAggregate.scala | 172 +++++++++++++++--- 1 file changed, 142 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a8a81d6d6574e..f61db8594dab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -258,6 +258,7 @@ case class TungstenAggregate( // The name for HashMap private var hashMapTerm: String = _ + private var sorterTerm: String = _ /** * This is called by generated Java class, should be public. @@ -286,39 +287,98 @@ case class TungstenAggregate( GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } - /** - * Update peak execution memory, called in generated Java class. + * Called by generated Java class to finish the aggregate and return a KVIterator. */ - def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) val metrics = TaskContext.get().taskMetrics() - metrics.incPeakExecutionMemory(mapMemory) - } + metrics.incPeakExecutionMemory(peakMemory) - private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + if (sorter == null) { + // not spilled + return hashMap.iterator() + } - // create hashMap - val thisPlan = ctx.addReferenceObj("plan", this) - hashMapTerm = ctx.freshName("hashMap") - val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. 
+ new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled)() + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } - // Create a name for iterator from HashMap - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + false + } + } - // generate code for output - val keyTerm = ctx.freshName("aggKey") - val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. + */ + private def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + plan: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).gen(ctx) } ctx.INPUT_ROW = bufferTerm val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => @@ -348,7 +408,7 @@ case class TungstenAggregate( // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + s"$joinerTerm = $plan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); @@ -367,6 +427,23 @@ case class TungstenAggregate( ${consume(ctx, eval)} """ } + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, 
iterTerm, "") val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, @@ -374,10 +451,15 @@ case class TungstenAggregate( private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - $iterTerm = $hashMapTerm.iterator(); + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); } """) + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + s""" if (!$initAgg) { $initAgg = true; @@ -391,8 +473,10 @@ case class TungstenAggregate( $outputCode } - $thisPlan.updatePeakMemory($hashMapTerm); - $hashMapTerm.free(); + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } """ } @@ -425,14 +509,42 @@ case class TungstenAggregate( ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) } + val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "", "") + } + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key ${keyCode.code} - UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + UnsafeRow $buffer = null; + if ($checkFallback) { + // try to get the buffer from hash map + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + } if ($buffer == null) { - // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation"); + if ($sorterTerm == null) { + $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + } else { + $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + } + $resetCoulter + // the hash map had be spilled, it should have enough memory now, + // try to allocate buffer again. + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } } + $incCounter // evaluate aggregate function ${evals.map(_.code).mkString("\n")} From 0557146619868002e2f7ec3c121c30bbecc918fc Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Tue, 2 Feb 2016 20:38:53 -0800 Subject: [PATCH 064/107] [SPARK-12732][ML] bug fix in linear regression train Fixed the bug in linear regression train for the case when the target variable is constant. The two cases for `fitIntercept=true` or `fitIntercept=false` should be treated differently. Author: Imran Younus Closes #10702 from iyounus/SPARK-12732_bug_fix_in_linear_regression_train. 
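The constant-label handling reduces to a small decision that can be sketched independently
of the spark.ml API (the helper below is hypothetical and only illustrates the branching,
assuming rawYStd == 0 signals a constant label): training can be skipped only when an
intercept is fitted or the label is identically zero; otherwise the solver still has to
run, regularization is rejected, and the label scale falls back to the absolute mean so
the objective stays well defined.

    // Hypothetical helper, not part of spark.ml. Returns Some((intercept, coefficient))
    // when training can be skipped, or None when the normal solver must run.
    def constantLabelShortcut(
        yMean: Double,
        rawYStd: Double,
        fitIntercept: Boolean,
        regParam: Double): Option[(Double, Double)] = {
      if (rawYStd == 0.0) {
        if (fitIntercept || yMean == 0.0) {
          // intercept = mean of the label, all coefficients zero
          Some((yMean, 0.0))
        } else {
          require(regParam == 0.0,
            "The standard deviation of the label is zero. Model cannot be regularized.")
          None // train through the origin, scaling y by math.abs(yMean) instead of rawYStd
        }
      } else {
        None
      }
    }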
--- .../ml/regression/LinearRegression.scala | 66 ++++++----- .../ml/regression/LinearRegressionSuite.scala | 105 ++++++++++++++++++ 2 files changed, 146 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c54e08b2ad9a5..e253f25c0ea65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val yMean = ySummarizer.mean(0) - val yStd = math.sqrt(ySummarizer.variance(0)) - - // If the yStd is zero, then the intercept is yMean with zero coefficient; - // as a result, training is not needed. - if (yStd == 0.0) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - if (handlePersistence) instances.unpersist() - val coefficients = Vectors.sparse(numFeatures, Seq()) - val intercept = yMean - - val model = new LinearRegressionModel(uid, coefficients, intercept) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - model, - Array(0D), - $(featuresCol), - Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + val rawYStd = math.sqrt(ySummarizer.variance(0)) + if (rawYStd == 0.0) { + if ($(fitIntercept) || yMean==0.0) { + // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with + // zero coefficient; as a result, training is not needed. + // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of + // the fitIntercept + if (yMean == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + if (handlePersistence) instances.unpersist() + val coefficients = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + model, + Array(0D), + $(featuresCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) + } else { + require($(regParam) == 0.0, "The standard deviation of the label is zero. " + + "Model cannot be regularized.") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } } + // if y is constant (rawYStd is zero), then y cannot be scaled. In this case + // setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. 
+ val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 273c882c2a47f..81fc6603ccfe6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -37,6 +37,8 @@ class LinearRegressionSuite @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ + @transient var datasetWithWeightConstantLabel: DataFrame = _ + @transient var datasetWithWeightZeroLabel: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -92,6 +94,29 @@ class LinearRegressionSuite Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + df.const.label <- as.data.frame(cbind(A, b.const)) + */ + datasetWithWeightConstantLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + datasetWithWeightZeroLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -558,6 +583,86 @@ class LinearRegressionSuite } } + test("linear regression model with constant label") { + /* + R code: + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + Seq("auto", "l-bfgs", "normal").foreach { solver => + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightConstantLabel) + val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), + model1.coefficients(1)) + assert(actual1 ~== expected(idx) absTol 1e-4) + + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightZeroLabel) + val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), + model2.coefficients(1)) + assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + idx += 1 + } + } + } + + test("regularized linear regression through origin with constant label") { + // The problem is ill-defined if fitIntercept=false, regParam is non-zero. + // An exception is thrown in this case. 
+ Seq("auto", "l-bfgs", "normal").foreach { solver => + for (standardization <- Seq(false, true)) { + val model = new LinearRegression().setFitIntercept(false) + .setRegParam(0.1).setStandardization(standardization).setSolver(solver) + intercept[IllegalArgumentException] { + model.fit(datasetWithWeightConstantLabel) + } + } + } + } + + test("linear regression with l-bfgs when training is not needed") { + // When label is constant, l-bfgs solver returns results without training. + // There are two possibilities: If the label is non-zero but constant, + // and fitIntercept is true, then the model return yMean as intercept without training. + // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so + // no training is needed. + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightConstantLabel) + if (fitIntercept) { + assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightZeroLabel) + assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + } + } + test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer = new LinearRegression().setSolver(solver) From 335f10edad8c759bad3dbd0660ed4dd5d70ddd8b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 2 Feb 2016 21:13:54 -0800 Subject: [PATCH 065/107] [SPARK-7997][CORE] Add rpcEnv.awaitTermination() back to SparkEnv `rpcEnv.awaitTermination()` was not added in #10854 because some Streaming Python tests hung forever. This patch fixed the hung issue and added rpcEnv.awaitTermination() back to SparkEnv. Previously, Streaming Kafka Python tests shutdowns the zookeeper server before stopping StreamingContext. Then when stopping StreamingContext, KafkaReceiver may be hung due to https://issues.apache.org/jira/browse/KAFKA-601, hence, some thread of RpcEnv's Dispatcher cannot exit and rpcEnv.awaitTermination is hung.The patch just changed the shutdown order to fix it. Author: Shixiong Zhu Closes #11031 from zsxwing/awaitTermination. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 1 + python/pyspark/streaming/tests.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 12c7b2048a8c8..9461afdc54124 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,6 +91,7 @@ class SparkEnv ( metricsSystem.stop() outputCommitCoordinator.stop() rpcEnv.shutdown() + rpcEnv.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 24b812615cbb4..b33e8252a7d32 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1013,12 +1013,12 @@ def setUp(self): self._kafkaTestUtils.setup() def tearDown(self): + super(KafkaStreamTests, self).tearDown() + if self._kafkaTestUtils is not None: self._kafkaTestUtils.teardown() self._kafkaTestUtils = None - super(KafkaStreamTests, self).tearDown() - def _randomTopic(self): return "topic-%d" % random.randint(0, 10000) From e86f8f63bfa3c15659b94e831b853b1bc9ddae32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 22:13:10 -0800 Subject: [PATCH 066/107] [SPARK-13147] [SQL] improve readability of generated code 1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. trop the new line for inlined blocks. Author: Davies Liu Closes #11032 from davies/better_suffix. --- .../sql/catalyst/expressions/Expression.scala | 8 +++-- .../expressions/codegen/CodeGenerator.scala | 27 ++++++++++------ .../expressions/complexTypeExtractors.scala | 31 +++++++++++-------- .../sql/execution/WholeStageCodegen.scala | 13 ++++---- .../aggregate/TungstenAggregate.scala | 14 ++++----- .../spark/sql/execution/basicOperators.scala | 7 ++++- .../BenchmarkWholeStageCodegen.scala | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 353fb92581d3b..c73b2f8f2a316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] { val value = ctx.freshName("value") val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + if (ve.code != "") { + // Add `this` in the comment. + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + } else { + ve + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a30aba16170a9..63e19564dd861 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -156,7 +156,11 @@ class CodegenContext { /** The variable name of the input row in generated code. */ final var INPUT_ROW = "i" - private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * The map from a variable name to it's next ID. + */ + private val freshNameIds = new mutable.HashMap[String, Int] + freshNameIds += INPUT_ROW -> 1 /** * A prefix used to generate fresh name. @@ -164,16 +168,21 @@ class CodegenContext { var freshNamePrefix = "" /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns a term name that is unique within this instance of a `CodegenContext`. 
*/ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" + def freshName(name: String): String = synchronized { + val fullName = if (freshNamePrefix == "") { + name + } else { + s"${freshNamePrefix}_$name" + } + if (freshNameIds.contains(fullName)) { + val id = freshNameIds(fullName) + freshNameIds(fullName) = id + 1 + s"$fullName$id" } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" + freshNameIds += fullName -> 1 + fullName } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9f2f82d68cca0..6b24fae9f3f1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -173,22 +173,26 @@ case class GetArrayStructFields( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val row = ctx.freshName("row") s""" - final int n = $eval.numElements(); - final Object[] values = new Object[n]; - for (int j = 0; j < n; j++) { - if ($eval.isNullAt(j)) { - values[j] = null; + final int $n = $eval.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($eval.isNullAt($j)) { + $values[$j] = null; } else { - final InternalRow row = $eval.getStruct(j, $numFields); - if (row.isNullAt($ordinal)) { - values[j] = null; + final InternalRow $row = $eval.getStruct($j, $numFields); + if ($row.isNullAt($ordinal)) { + $values[$j] = null; } else { - values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } } - ${ev.value} = new $arrayClass(values); + ${ev.value} = new $arrayClass($values); """ }) } @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") s""" - final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; } """ }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 02b0f423ed438..14754969072f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { s""" | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); - | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, columns)} + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} | } """.stripMargin } @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, 
children: Seq[SparkPlan]) private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} + this.references = references; + ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + protected void processNext() throws java.io.IOException { - $code + ${code.trim} } } """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index f61db8594dab2..d0244770613d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -211,9 +211,9 @@ case class TungstenAggregate( | $doAgg(); | | // output the result - | $genResult + | ${genResult.trim} | - | ${consume(ctx, resultVars)} + | ${consume(ctx, resultVars).trim} | } """.stripMargin } @@ -242,9 +242,9 @@ case class TungstenAggregate( } s""" | // do aggregate - | ${aggVals.map(_.code).mkString("\n")} + | ${aggVals.map(_.code).mkString("\n").trim} | // update aggregation buffer - | ${updates.mkString("")} + | ${updates.mkString("\n").trim} """.stripMargin } @@ -523,7 +523,7 @@ case class TungstenAggregate( // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key - ${keyCode.code} + ${keyCode.code.trim} UnsafeRow $buffer = null; if ($checkFallback) { // try to get the buffer from hash map @@ -547,9 +547,9 @@ case class TungstenAggregate( $incCounter // evaluate aggregate function - ${evals.map(_.code).mkString("\n")} + ${evals.map(_.code).mkString("\n").trim} // update aggregate buffer - ${updates.mkString("\n")} + ${updates.mkString("\n").trim} """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fd81531c9316a..ae4422195cc4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit BindReferences.bindReference(condition, child.output)) ctx.currentVars = input val eval = expr.gen(ctx) + val nullCheck = if (expr.nullable) { + s"!${eval.isNull} &&" + } else { + s"" + } s""" | ${eval.code} - | if (!${eval.isNull} && ${eval.value}) { + | if ($nullCheck ${eval.value}) { | ${consume(ctx, ctx.currentVars)} | } """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 1ccf0e3d0656c..ec2b9ab2cbad5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // These benchmark are skipped in normal build ignore("benchmark") { // testWholeStage(200 << 20) - // testStddev(20 << 20) + // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } From 138c300f97d29cb0d04a70bea98a8a0c0548318a Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 2 
Feb 2016 22:22:50 -0800 Subject: [PATCH 067/107] [SPARK-12957][SQL] Initial support for constraint propagation in SparkSQL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on the semantics of a query, we can derive a number of data constraints on output of each (logical or physical) operator. For instance, if a filter defines `‘a > 10`, we know that the output data of this filter satisfies 2 constraints: 1. `‘a > 10` 2. `isNotNull(‘a)` This PR proposes a possible way of keeping track of these constraints and propagating them in the logical plan, which can then help us build more advanced optimizations (such as pruning redundant filters, optimizing joins, among others). We define constraints as a set of (implicitly conjunctive) expressions. For e.g., if a filter operator has constraints = `Set(‘a > 10, ‘b < 100)`, it’s implied that the outputs satisfy both individual constraints (i.e., `‘a > 10` AND `‘b < 100`). Design Document: https://docs.google.com/a/databricks.com/document/d/1WQRgDurUBV9Y6CWOBS75PQIqJwT-6WftVa18xzm7nCo/edit?usp=sharing Author: Sameer Agarwal Closes #10844 from sameeragarwal/constraints. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 55 +++++- .../catalyst/plans/logical/LogicalPlan.scala | 2 + .../plans/logical/basicOperators.scala | 79 +++++++- .../plans/ConstraintPropagationSuite.scala | 173 ++++++++++++++++++ 4 files changed, 302 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b43b7ee71e7aa..05f5bdbfc0769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -26,6 +26,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def output: Seq[Attribute] + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions. + * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form + * `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // Currently we only propagate constraints if the condition consists of equality + // and ranges. 
For all other cases, we return an empty set of constraints + constraints.map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) + } + + /** + * A sequence of expressions that describes the data property of the output rows of this + * operator. For example, if the output of this operator is column `a`, an example `constraints` + * can be `Set(a > 10, a < 20)`. + */ + lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]] + */ + protected def validConstraints: Set[Expression] = Set.empty + /** * Returns the set of attributes that are output by this node. */ @@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. 
* @return */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c52..d8944a424156e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + override protected def validConstraints: Set[Expression] = child.constraints } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 16f4b355b1b6c..8150ff8434762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -87,11 +87,27 @@ case class Generate( } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + + override protected def validConstraints: Set[Expression] = { + child.constraints.union(splitConjunctivePredicates(condition).toSet) + } } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + + protected def leftConstraints: Set[Expression] = left.constraints + + protected def rightConstraints: Set[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + right.constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } +} private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -106,6 +122,10 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + override protected def validConstraints: Set[Expression] = { + leftConstraints.union(rightConstraints) + } + // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. override lazy val resolved: Boolean = @@ -119,6 +139,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + override protected def validConstraints: Set[Expression] = leftConstraints + override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -157,13 +179,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } + + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences are symmetric. 
+ */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], + constraints: Set[Expression]): Set[Expression] = { + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) + .reduce(_ intersect _) + } } case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { @@ -180,6 +225,28 @@ case class Join( } } + override protected def validConstraints: Set[Expression] = { + joinType match { + case Inner if condition.isDefined => + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + case LeftSemi if condition.isDefined => + left.constraints + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftSemi => + left.constraints + case LeftOuter => + left.constraints + case RightOuter => + right.constraints + case FullOuter => + Set.empty[Expression] + } + } + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 0000000000000..b5cf91394d910 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + + private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { + val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) + val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } + } + + test("propagating constraints in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a")))) + + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr < 100) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") < 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c")))) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + assert(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + + verifyConstraints(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b")))) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 
10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + assert(tr1.where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints.isEmpty) + } +} From e9eb248edfa81d75f99c9afc2063e6b3d9ee7392 Mon Sep 17 00:00:00 2001 From: Mario Briggs Date: Wed, 3 Feb 2016 09:50:28 -0800 Subject: [PATCH 068/107] [SPARK-12739][STREAMING] Details of batch in Streaming tab uses two Duration columns I have clearly prefix the two 'Duration' columns in 'Details of Batch' Streaming tab as 'Output Op Duration' and 'Job Duration' Author: Mario Briggs Author: mariobriggs Closes #11022 from mariobriggs/spark-12739. 
--- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 4 ++-- .../scala/org/apache/spark/streaming/UISeleniumSuite.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 7635f79a3d2d1..81de07f933f8a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -37,10 +37,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def columns: Seq[Node] = { Output Op Id Description - Duration + Output Op Duration Status Job Id - Duration + Job Duration Stages: Succeeded/Total Tasks (for all stages): Succeeded/Total Error diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index c4ecebcacf3c8..96dd4757be855 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -143,8 +143,9 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", - "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + List("Output Op Id", "Description", "Output Op Duration", "Status", "Job Id", + "Job Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", + "Error") } // Check we have 2 output op ids From c4feec26eb677bfd3bfac38e5e28eae05279956e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 10:38:53 -0800 Subject: [PATCH 069/107] [SPARK-12798] [SQL] generated BroadcastHashJoin A row from stream side could match multiple rows on build side, the loop for these matched rows should not be interrupted when emitting a row, so we buffer the output rows in a linked list, check the termination condition on producer loop (for example, Range or Aggregate). Author: Davies Liu Closes #10989 from davies/gen_join. --- .../sql/execution/BufferedRowIterator.java | 30 ++++-- .../sql/execution/WholeStageCodegen.scala | 18 ++-- .../aggregate/TungstenAggregate.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 2 + .../execution/joins/BroadcastHashJoin.scala | 92 ++++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 28 +++++- .../execution/WholeStageCodegenSuite.scala | 15 ++- 7 files changed, 169 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index 6acf70dbbad0c..ea20115770f79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.LinkedList; import scala.collection.Iterator; +import org.apache.spark.TaskContext; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -31,28 +33,42 @@ * TODO: replaced it by batched columnar format. 
*/ public class BufferedRowIterator { - protected InternalRow currentRow; + protected LinkedList currentRows = new LinkedList<>(); protected Iterator input; // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); public boolean hasNext() throws IOException { - if (currentRow == null) { + if (currentRows.isEmpty()) { processNext(); } - return currentRow != null; + return !currentRows.isEmpty(); } public InternalRow next() { - InternalRow r = currentRow; - currentRow = null; - return r; + return currentRows.remove(); } public void setInput(Iterator iter) { input = iter; } + /** + * Returns whether `processNext()` should stop processing next row from `input` or not. + * + * If it returns true, the caller should exit the loop (return from processNext()). + */ + protected boolean shouldStop() { + return !currentRows.isEmpty(); + } + + /** + * Increase the peak execution memory for current task. + */ + protected void incPeakExecutionMemory(long size) { + TaskContext.get().taskMetrics().incPeakExecutionMemory(size); + } + /** * Processes the input until have a row as output (currentRow). * @@ -60,7 +76,7 @@ public void setInput(Iterator iter) { */ protected void processNext() throws IOException { if (input.hasNext()) { - currentRow = input.next(); + currentRows.add(input.next()); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 14754969072f9..131efea20f31e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} import org.apache.spark.util.Utils /** @@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} + | if (shouldStop()) { + | return; + | } | } """.stripMargin } @@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRow = $row; - | return; + | currentRows.add($row.copy()); """.stripMargin } else { assert(input != null) @@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" | ${code.code.trim} - | currentRow = ${code.value}; - | return; + | currentRows.add(${code.value}.copy()); """.stripMargin } else { // There is no columns s""" - | currentRow = unsafeRow; - | return; + | currentRows.add(unsafeRow); """.stripMargin } } @@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { + // The build side can't be compiled together + case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + b.copy(left = apply(left)) + case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) 
// collapse them recursively inputs += input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d0244770613d3..9d9f14f2dd014 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -471,6 +471,8 @@ case class TungstenAggregate( UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); $outputCode + + if (shouldStop()) return; } $iterTerm.close(); @@ -480,7 +482,7 @@ case class TungstenAggregate( """ } - private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key ctx.currentVars = input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ae4422195cc4b..6e51c4d84824a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -237,6 +237,8 @@ case class Range( | $overflow = true; | } | ${consume(ctx, Seq(ev))} + | + | if (shouldStop()) return; | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 04640711d99d0..8b275e886c46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -42,7 +45,7 @@ case class BroadcastHashJoin( condition: Option[Expression], left: SparkPlan, right: SparkPlan) - extends BinaryNode with HashJoin { + extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), @@ -117,6 +120,87 @@ case class BroadcastHashJoin( hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + + // the term for hash relation + private var relationTerm: String = _ + + override def upstream(): RDD[InternalRow] = { + streamedPlan.asInstanceOf[CodegenSupport].upstream() + } + + override def doProduce(ctx: CodegenContext): String = { + // create a name for HashRelation + val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + relationTerm = ctx.freshName("relation") + // TODO: create specialized HashRelation for single join key + val clsName = classOf[UnsafeHashedRelation].getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $broadcast.value(); + | incPeakExecutionMemory($relationTerm.getUnsafeSize()); + """.stripMargin) + + s""" + | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} + """.stripMargin + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow + ctx.currentVars = input + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val keyTerm = keyVal.value + val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + + // find the matches from HashedRelation + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val row = ctx.freshName("row") + + // create variables for output + ctx.currentVars = null + ctx.INPUT_ROW = row + val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + val ouputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + | ${ev.code} + | if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + | } + """.stripMargin + } else { + consume(ctx, resultVars) + } + + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = $anyNull ? 
null : ($bufferType) $relationTerm.get($keyTerm); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $row = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $ouputCode + | } + | } + """.stripMargin + } } object BroadcastHashJoin { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index ec2b9ab2cbad5..15ba77353109c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testBroadcastHashJoin(values: Int): Unit = { + val benchmark = new Benchmark("BroadcastHashJoin", values) + + val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + + benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X + BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + */ + benchmark.run() + } + def testBytesToBytesMap(values: Int): Unit = { val benchmark = new Benchmark("BytesToBytesMap", values) @@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // testWholeStage(200 << 20) // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(1024 * 1024 * 50) + // testBytesToBytesMap(50 << 20) + // testBroadcastHashJoin(10 << 20) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c2516509dfbbf..9350205d791d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.functions.{avg, col, max} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} 
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { @@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } + + test("BroadcastHashJoin should be included in WholeStageCodegen") { + val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val schema = new StructType().add("k", IntegerType).add("v", StringType) + val smallDF = sqlContext.createDataFrame(rdd, schema) + val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + assert(df.queryExecution.executedPlan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) + } } From 9dd2741ebe5f9b5fa0a3b0e9c594d0e94b6226f9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 3 Feb 2016 12:31:30 -0800 Subject: [PATCH 070/107] [SPARK-13157] [SQL] Support any kind of input for SQL commands. The ```SparkSqlLexer``` currently swallows characters which have not been defined in the grammar. This causes problems with SQL commands, such as: ```add jar file:///tmp/ab/TestUDTF.jar```. In this example the `````` is swallowed. This PR adds an extra Lexer rule to handle such input, and makes a tiny modification to the ```ASTNode```. cc davies liancheng Author: Herman van Hovell Closes #11052 from hvanhovell/SPARK-13157. --- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 ++ .../spark/sql/catalyst/parser/ASTNode.scala | 4 +- .../sql/catalyst/parser/ASTNodeSuite.scala | 38 +++++++++++++++++++ .../HiveThriftServer2Suites.scala | 6 +-- 4 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e930caa291d4f..1d07a27353dcb 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -483,3 +483,7 @@ COMMENT { $channel=HIDDEN; } ; +/* Prevent that the lexer swallows unknown characters. */ +ANY + :. + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala index ec9812414e19f..28f7b10ed6a59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -58,12 +58,12 @@ case class ASTNode( override val origin: Origin = Origin(Some(line), Some(positionInLine)) /** Source text. */ - lazy val source: String = stream.toString(startIndex, stopIndex) + lazy val source: String = stream.toOriginalString(startIndex, stopIndex) /** Get the source text that remains after this token. 
*/ lazy val remainder: String = { stream.fill() - stream.toString(stopIndex + 1, stream.size() - 1).trim() + stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim() } def text: String = token.getText diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala new file mode 100644 index 0000000000000..8b05f9e33d69e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ASTNodeSuite extends SparkFunSuite { + test("SPARK-13157 - remainder must return all input chars") { + val inputs = Seq( + ("add jar", "file:///tmp/ab/TestUDTF.jar"), + ("add jar", "file:///tmp/a@b/TestUDTF.jar"), + ("add jar", "c:\\windows32\\TestUDTF.jar"), + ("add jar", "some \nbad\t\tfile\r\n.\njar"), + ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"), + ("SET", "foo=bar"), + ("SET", "foo*)(@#^*@&!#^=bar") + ) + inputs.foreach { + case (command, arguments) => + val node = ParseDriver.parsePlan(s"$command $arguments", null) + assert(node.remainder === arguments) + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 9860e40fe8546..ba3b26e1b7d49 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,8 +488,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - // TODO: enable this - ignore("SPARK-11595 ADD JAR with input path having URL scheme") { + test("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -547,8 +546,7 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - // TODO: enable this - ignore("test single session") { + test("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 3221eddb8f9728f65c579969a3a88baeeb7577a9 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 3 Feb 2016 15:53:10 -0800 Subject: [PATCH 071/107] [SPARK-3611][WEB UI] Show number of cores for each executor in application web 
UI Added a Cores column in the Executors UI Author: Alex Bozarth Closes #11039 from ajbozarth/spark3611. --- .../main/scala/org/apache/spark/status/api/v1/api.scala | 1 + .../scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 7 +++++++ .../main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 5 +++-- .../executor_list_json_expectation.json | 1 + 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2b0079f5fd62e..d116e68c17f18 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -57,6 +57,7 @@ class ExecutorSummary private[spark]( val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val totalCores: Int, val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index e36b96b3e6978..e1f754999912b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -75,6 +75,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -131,6 +132,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {info.totalCores} {taskData(info.maxTasks, info.activeTasks, info.failedTasks, info.completedTasks, info.totalTasks, info.totalDuration, info.totalGCTime)} @@ -174,6 +176,7 @@ private[ui] class ExecutorsPage( val maximumMemory = execInfo.map(_.maxMemory).sum val memoryUsed = execInfo.map(_.memoryUsed).sum val diskUsed = execInfo.map(_.diskUsed).sum + val totalCores = execInfo.map(_.totalCores).sum val totalInputBytes = execInfo.map(_.totalInputBytes).sum val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum @@ -188,6 +191,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {totalCores} {taskData(execInfo.map(_.maxTasks).sum, execInfo.map(_.activeTasks).sum, execInfo.map(_.failedTasks).sum, @@ -211,6 +215,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -305,6 +310,7 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed + val totalCores = listener.executorToTotalCores.getOrElse(execId, 0) val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) @@ -323,6 +329,7 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, + totalCores, maxTasks, activeTasks, failedTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index a9e926b158780..dcfebe92ed805 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { + val executorToTotalCores = HashMap[String, Int]() 
val executorToTasksMax = HashMap[String, Int]() val executorToTasksActive = HashMap[String, Int]() val executorToTasksComplete = HashMap[String, Int]() @@ -65,8 +66,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap - executorToTasksMax(eid) = - executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1) + executorToTotalCores(eid) = executorAdded.executorInfo.totalCores + executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1) executorIdToData(eid) = ExecutorUIData(executorAdded.time) } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 94f8aeac55b5d..9d5d224e55176 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -4,6 +4,7 @@ "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, + "totalCores" : 0, "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, From 915a75398ecbccdbf9a1e07333104c857ae1ce5e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 16:10:11 -0800 Subject: [PATCH 072/107] [SPARK-13166][SQL] Remove DataStreamReader/Writer They seem redundant and we can simply use DataFrameReader/Writer. The new usage looks like: ```scala val df = sqlContext.read.stream("...") val handle = df.write.stream("...") handle.stop() ``` Author: Reynold Xin Closes #11062 from rxin/SPARK-13166. --- .../org/apache/spark/sql/DataFrame.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 29 +++- .../apache/spark/sql/DataFrameWriter.scala | 36 ++++- .../apache/spark/sql/DataStreamReader.scala | 127 ----------------- .../apache/spark/sql/DataStreamWriter.scala | 134 ------------------ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../datasources/ResolvedDataSource.scala | 1 - .../sql/streaming/DataStreamReaderSuite.scala | 53 ++++--- 8 files changed, 86 insertions(+), 315 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6de17e5924d04..84203bbfef66a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1682,7 +1682,7 @@ class DataFrame private[sql]( /** * :: Experimental :: - * Interface for saving the content of the [[DataFrame]] out into external storage. + * Interface for saving the content of the [[DataFrame]] out into external storage or streams. * * @group output * @since 1.4.0 @@ -1690,14 +1690,6 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) - /** - * :: Experimental :: - * Interface for starting a streaming query that will continually output results to the specified - * external sink as new data arrives. - */ - @Experimental - def streamTo: DataStreamWriter = new DataStreamWriter(this) - /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. 
* @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2e0c6c7df967e..a58643a5ba154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,17 +29,17 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystQl} import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType /** * :: Experimental :: * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc). Use [[SQLContext.read]] to access this. + * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this. * * @since 1.4.0 */ @@ -136,6 +136,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } + /** + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + def stream(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def stream(path: String): DataFrame = { + option("path", path).stream() + } + /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. @@ -165,7 +189,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. 
- * * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 12eb2393634a9..80447fefe1f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,17 +22,18 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.sources.HadoopFsRelation /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc). Use [[DataFrame.write]] to access this. + * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. * * @since 1.4.0 */ @@ -183,6 +184,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(path: String): ContinuousQuery = { + option("path", path).stream() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols.getOrElse(Nil)) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. @@ -255,7 +284,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * The given column name may not be equal to any of the existing column names if we were in - * case-insensitive context. Normalize the given column name to the real one so that we don't + * case-insensitive context. Normalize the given column name to the real one so that we don't * need to care about case sensitivity afterwards. */ private def normalize(columnName: String, columnType: String): String = { @@ -339,7 +368,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. 
- * * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala deleted file mode 100644 index 2febc93fa49d4..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamingRelation -import org.apache.spark.sql.types.StructType - -/** - * :: Experimental :: - * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. - * - * {{{ - * val df = sqlContext.streamFrom - * .format("...") - * .open() - * }}} - */ -@Experimental -class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { - - /** - * Specifies the input data source format. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamReader = { - this.source = source - this - } - - /** - * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data stream can - * skip the schema inference step, and thus speed up data reading. - * - * @since 2.0.0 - */ - def schema(schema: StructType): DataStreamReader = { - this.userSpecifiedSchema = Option(schema) - this - } - - /** - * Adds an input option for the underlying data stream. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamReader = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { - this.extraOptions ++= options - this - } - - /** - * Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { - this.options(options.asScala) - this - } - - /** - * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. - * external key-value stores). 
- * - * @since 2.0.0 - */ - def open(): DataFrame = { - val resolved = ResolvedDataSource.createSource( - sqlContext, - userSpecifiedSchema = userSpecifiedSchema, - providerName = source, - options = extraOptions.toMap) - DataFrame(sqlContext, StreamingRelation(resolved)) - } - - /** - * Loads input in as a [[DataFrame]], for data streams that read from some path. - * - * @since 2.0.0 - */ - def open(path: String): DataFrame = { - option("path", path).open() - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = sqlContext.conf.defaultDataSourceName - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala deleted file mode 100644 index b325d48fcbbb1..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamExecution - -/** - * :: Experimental :: - * Interface used to start a streaming query query execution. - * - * @since 2.0.0 - */ -@Experimental -final class DataStreamWriter private[sql](df: DataFrame) { - - /** - * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamWriter = { - this.source = source - this - } - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamWriter = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter = { - this.extraOptions ++= options - this - } - - /** - * Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter = { - this.options(options.asScala) - this - } - - /** - * Partitions the output by the given columns on the file system. 
If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme.\ - * @since 2.0.0 - */ - @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter = { - this.partitioningColumns = colNames - this - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * @since 2.0.0 - */ - def start(path: String): ContinuousQuery = { - this.extraOptions += ("path" -> path) - start() - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - def start(): ContinuousQuery = { - val sink = ResolvedDataSource.createSink( - df.sqlContext, - source, - extraOptions.toMap, - normalizedParCols) - - new StreamExecution(df.sqlContext, df.logicalPlan, sink) - } - - private def normalizedParCols: Seq[String] = { - partitioningColumns.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = df.sqlContext.conf.defaultDataSourceName - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - - private var partitioningColumns: Seq[String] = Nil - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 13700be06828d..1661fdbec5326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -579,10 +579,9 @@ class SQLContext private[sql]( DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) } - /** * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -594,14 +593,6 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) - - /** - * :: Experimental :: - * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. - */ - @Experimental - def streamFrom: DataStreamReader = new DataStreamReader(this) - /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index e3065ac5f87d2..7702f535ad2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -122,7 +122,6 @@ object ResolvedDataSource extends Logging { provider.createSink(sqlContext, options, partitionColumns) } - /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala index 1dab6ebf1bee9..b36b41cac9b4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -60,22 +60,22 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { import testImplicits._ test("resolve default source") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } test("resolve full class") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test.DefaultSource") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } @@ -83,12 +83,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .open() + .stream() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -96,12 +96,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .start() + .stream() .stop() assert(LastOptions.parameters("opt1") == "1") @@ -110,54 +110,53 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { } test("partitioning") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() + .stream() - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Nil) - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("a") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) - withSQLConf("spark.sql.caseSensitive" -> "false") { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("A") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("b") - .start() + .stream() .stop() } } test("stream paths") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - 
.open("/test") + .stream("/test") assert(LastOptions.parameters("path") == "/test") LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start("/test") + .stream("/test") .stop() assert(LastOptions.parameters("path") == "/test") From de0914522fc5b2658959f9e2272b4e3162b14978 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 17:07:27 -0800 Subject: [PATCH 073/107] [SPARK-13131] [SQL] Use best and average time in benchmark Best time is stabler than average time, also added a column for nano seconds per row (which could be used to estimate contributions of each components in a query). Having best time and average time together for more information (we can see kind of variance). rate, time per row and relative are all calculated using best time. The result looks like this: ``` Intel(R) Core(TM) i7-4558U CPU 2.80GHz rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X ``` Author: Davies Liu Closes #11018 from davies/gen_bench. --- .../org/apache/spark/util/Benchmark.scala | 38 +++-- .../BenchmarkWholeStageCodegen.scala | 154 ++++++++---------- 2 files changed, 89 insertions(+), 103 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index d484cec7ae384..d1699f5c28655 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.SystemUtils @@ -59,17 +60,21 @@ private[spark] class Benchmark( } println - val firstRate = results.head.avgRate + val firstBest = results.head.bestMs + val firstAvg = results.head.avgMs // The results are going to be processor specific so it is useful to include that. println(Benchmark.getProcessorName()) - printf("%-30s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") - println("-------------------------------------------------------------------------------") - results.zip(benchmarks).foreach { r => - printf("%-30s %16s %16s %14s\n", - r._2.name, - "%10.2f" format r._1.avgMs, - "%10.2f" format r._1.avgRate, - "%6.2f X" format (r._1.avgRate / firstRate)) + printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", + "Per Row(ns)", "Relative") + println("-----------------------------------------------------------------------------------" + + "--------") + results.zip(benchmarks).foreach { case (result, benchmark) => + printf("%-35s %16s %12s %13s %10s\n", + benchmark.name, + "%5.0f / %4.0f" format (result.bestMs, result.avgMs), + "%10.1f" format result.bestRate, + "%6.1f" format (1000 / result.bestRate), + "%3.1fX" format (firstBest / result.bestMs)) } println // scalastyle:on @@ -78,7 +83,7 @@ private[spark] class Benchmark( private[spark] object Benchmark { case class Case(name: String, fn: Int => Unit) - case class Result(avgMs: Double, avgRate: Double) + case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** * This should return a user helpful processor information. Getting at this depends on the OS. @@ -99,22 +104,27 @@ private[spark] object Benchmark { * the rate of the function. 
*/ def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { - var totalTime = 0L + val runTimes = ArrayBuffer[Long]() for (i <- 0 until iters + 1) { val start = System.nanoTime() f(i) val end = System.nanoTime() - if (i != 0) totalTime += end - start + val runTime = end - start + if (i > 0) { + runTimes += runTime + } if (outputPerIteration) { // scalastyle:off - println(s"Iteration $i took ${(end - start) / 1000} microseconds") + println(s"Iteration $i took ${runTime / 1000} microseconds") // scalastyle:on } } - Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000)) + val best = runTimes.min + val avg = runTimes.sum / iters + Result(avg / 1000000, num / (best / 1000), best / 1000000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 15ba77353109c..33d4976403d9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -34,54 +34,47 @@ import org.apache.spark.util.Benchmark */ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + .set("spark.sql.shuffle.partitions", "1") lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) - def testWholeStage(values: Int): Unit = { - val benchmark = new Benchmark("rang/filter/aggregate", values) + def runBenchmark(name: String, values: Int)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, values) - benchmark.addCase("Without codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).filter("(id & 1) = 1").count() - } - - benchmark.addCase("With codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).filter("(id & 1) = 1").count() + Seq(false, true).foreach { enabled => + benchmark.addCase(s"$name codegen=$enabled") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) + f + } } - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Without codegen 7775.53 26.97 1.00 X - With codegen 342.15 612.94 22.73 X - */ benchmark.run() } - def testStatFunctions(values: Int): Unit = { - - val benchmark = new Benchmark("stat functions", values) - - benchmark.addCase("stddev w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + // These benchmark are skipped in normal build + ignore("range/filter/sum") { + val N = 500 << 20 + runBenchmark("rang/filter/sum", N) { + sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() } + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X + rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X + */ + } - benchmark.addCase("stddev w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - 
sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() - } + ignore("stat functions") { + val N = 100 << 20 - benchmark.addCase("kurtosis w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("stddev", N) { + sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() } - benchmark.addCase("kurtosis w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("kurtosis", N) { + sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect() } @@ -99,64 +92,56 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Using DeclarativeAggregate: Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - stddev w/o codegen 989.22 21.20 1.00 X - stddev w codegen 352.35 59.52 2.81 X - kurtosis w/o codegen 3636.91 5.77 0.27 X - kurtosis w codegen 369.25 56.79 2.68 X + stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + stddev codegen=false 5630 / 5776 18.0 55.6 1.0X + stddev codegen=true 1259 / 1314 83.0 12.0 4.5X + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X + kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X */ - benchmark.run() } - def testAggregateWithKey(values: Int): Unit = { - val benchmark = new Benchmark("Aggregate with keys", values) + ignore("aggregate with keys") { + val N = 20 << 20 - benchmark.addCase("Aggregate w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() - } - benchmark.addCase(s"Aggregate w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + runBenchmark("Aggregate w keys", N) { + sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Aggregate w/o codegen 4254.38 4.93 1.00 X - Aggregate w codegen 2661.45 7.88 1.60 X + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X + Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X */ - benchmark.run() } - def testBroadcastHashJoin(values: Int): Unit = { - val benchmark = new Benchmark("BroadcastHashJoin", values) - + ignore("broadcast hash join") { + val N = 20 << 20 val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) - benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() - } - benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => - 
sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + runBenchmark("BroadcastHashJoin", N) { + sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() } /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X - BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X + BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X */ - benchmark.run() } - def testBytesToBytesMap(values: Int): Unit = { - val benchmark = new Benchmark("BytesToBytesMap", values) + ignore("hash and BytesToBytesMap") { + val N = 50 << 20 + + val benchmark = new Benchmark("BytesToBytesMap", N) benchmark.addCase("hash") { iter => var i = 0 @@ -167,7 +152,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 1000) val h = Murmur3_x86_32.hashUnsafeWords( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) @@ -194,7 +179,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) if (loc.isDefined) { @@ -212,21 +197,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - hash 662.06 79.19 1.00 X - BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X - BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + hash 628 / 661 83.0 12.0 1.0X + BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X + BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X */ benchmark.run() } - - // These benchmark are skipped in normal build - ignore("benchmark") { - // testWholeStage(200 << 20) - // testStatFunctions(20 << 20) - // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(50 << 20) - // testBroadcastHashJoin(10 << 20) - } } From a8e2ba776b20c8054918af646d8228bba1b87c9b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 3 Feb 2016 17:43:14 -0800 Subject: [PATCH 074/107] [SPARK-13152][CORE] Fix task metrics deprecation warning Make an internal non-deprecated version of incBytesRead and incRecordsRead so we don't have unecessary deprecation warnings in our build. Right now incBytesRead and incRecordsRead are marked as deprecated and for internal use only. We should make private[spark] versions which are not deprecated and switch to those internally so as to not clutter up the warning messages when building. 
cc andrewor14 who did the initial deprecation Author: Holden Karau Closes #11056 from holdenk/SPARK-13152-fix-task-metrics-deprecation-warnings. --- core/src/main/scala/org/apache/spark/CacheManager.scala | 4 ++-- .../main/scala/org/apache/spark/executor/InputMetrics.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 4 ++-- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index fa8e2b953835b..923ff411ce252 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,12 +44,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(blockResult) => // Partition is already materialized, so just return its values val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesRead(blockResult.bytes) + existingMetrics.incBytesReadInternal(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] new InterruptibleIterator[T](context, iter) { override def next(): T = { - existingMetrics.incRecordsRead(1) + existingMetrics.incRecordsReadInternal(1) delegate.next() } } diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index ed9e157ce758b..6d30d3c76a9fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -81,10 +81,15 @@ class InputMetrics private ( */ def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + // Once incBytesRead & intRecordsRead is ready to be removed from the public API + // we can remove the internal versions and make the previous public API private. + // This has been done to suppress warnings when building. @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v) @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e2ebd7f00d0d5..805cd9fe1f638 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -260,7 +260,7 @@ class HadoopRDD[K, V]( finished = true } if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -292,7 +292,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
try { - inputMetrics.incBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e71d3405c0ead..f23da39eb90de 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -188,7 +188,7 @@ class NewHadoopRDD[K, V]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -219,7 +219,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a2487eeb0483a..38e6478d80f03 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -811,8 +811,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) val inputMetrics = metrics.registerInputMetrics(readMethod) - inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } // Updated blocks diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 9703b16c86f90..3605150b3b767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -214,7 +214,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -246,7 +246,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) From a64831124c215f56f124747fa241560c70cf0a36 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Feb 2016 19:32:41 -0800 Subject: [PATCH 075/107] [SPARK-13079][SQL] Extend and implement InMemoryCatalog This is a step towards consolidating `SQLContext` and `HiveContext`. This patch extends the existing Catalog API added in #10982 to include methods for handling table partitions. In particular, a partition is identified by `PartitionSpec`, which is just a `Map[String, String]`. The Catalog is still not used by anything yet, but its API is now more or less complete and an implementation is fully tested. About 200 lines are test code. Author: Andrew Or Closes #11069 from andrewor14/catalog. --- .../catalyst/catalog/InMemoryCatalog.scala | 129 ++++++++--- .../sql/catalyst/catalog/interface.scala | 40 +++- .../catalyst/catalog/CatalogTestCases.scala | 206 +++++++++++++++++- 3 files changed, 328 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9e6dfb7e9506f..38be61c52a95e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,9 +28,10 @@ import org.apache.spark.sql.AnalysisException * All public methods should be synchronized for thread-safety. */ class InMemoryCatalog extends Catalog { + import Catalog._ private class TableDesc(var table: Table) { - val partitions = new mutable.HashMap[String, TablePartition] + val partitions = new mutable.HashMap[PartitionSpec, TablePartition] } private class DatabaseDesc(var db: Database) { @@ -46,13 +47,20 @@ class InMemoryCatalog extends Catalog { } private def existsFunction(db: String, funcName: String): Boolean = { + assertDbExists(db) catalog(db).functions.contains(funcName) } private def existsTable(db: String, table: String): Boolean = { + assertDbExists(db) catalog(db).tables.contains(table) } + private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = { + assertTableExists(db, table) + catalog(db).tables(table).partitions.contains(spec) + } + private def assertDbExists(db: String): Unit = { if (!catalog.contains(db)) { throw new AnalysisException(s"Database $db does not exist") @@ -60,16 +68,20 @@ class InMemoryCatalog extends Catalog { } private def assertFunctionExists(db: String, funcName: String): Unit = { - assertDbExists(db) if (!existsFunction(db, funcName)) { - throw new AnalysisException(s"Function $funcName does not exists in $db database") + throw new AnalysisException(s"Function $funcName does not exist in $db database") } } private def assertTableExists(db: String, table: String): Unit = { - assertDbExists(db) if (!existsTable(db, table)) { - throw new AnalysisException(s"Table $table does not exists in $db database") + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + + private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = { + if (!existsPartition(db, table, spec)) { + throw new AnalysisException(s"Partition does not exist in database $db table 
$table: $spec") } } @@ -77,9 +89,11 @@ class InMemoryCatalog extends Catalog { // Databases // -------------------------------------------------------------------------- - override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + override def createDatabase( + dbDefinition: Database, + ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") } } else { @@ -88,9 +102,9 @@ class InMemoryCatalog extends Catalog { } override def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit = synchronized { + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { if (catalog.contains(db)) { if (!cascade) { // If cascade is false, make sure the database is empty. @@ -133,11 +147,13 @@ class InMemoryCatalog extends Catalog { // Tables // -------------------------------------------------------------------------- - override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) - : Unit = synchronized { + override def createTable( + db: String, + tableDefinition: Table, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, tableDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") } } else { @@ -145,8 +161,10 @@ class InMemoryCatalog extends Catalog { } } - override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) - : Unit = synchronized { + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, table)) { catalog(db).tables.remove(table) @@ -190,14 +208,67 @@ class InMemoryCatalog extends Catalog { // Partitions // -------------------------------------------------------------------------- - override def alterPartition(db: String, table: String, part: TablePartition) - : Unit = synchronized { - throw new UnsupportedOperationException + override def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfExists) { + val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } + if (dupSpecs.nonEmpty) { + val dupSpecsStr = dupSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions already exist in database $db table $table:\n$dupSpecsStr") + } + } + parts.foreach { p => existingParts.put(p.spec, p) } + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfNotExists) { + val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } + if (missingSpecs.nonEmpty) { + val missingSpecsStr = missingSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions do not exist in database $db table $table:\n$missingSpecsStr") + } + } + partSpecs.foreach(existingParts.remove) } - override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) - : Unit = synchronized { - throw new UnsupportedOperationException + 
override def alterPartition( + db: String, + table: String, + spec: Map[String, String], + newPart: TablePartition): Unit = synchronized { + assertPartitionExists(db, table, spec) + val existingParts = catalog(db).tables(table).partitions + if (spec != newPart.spec) { + // Also a change in specs; remove the old one and add the new one back + existingParts.remove(spec) + } + existingParts.put(newPart.spec, newPart) + } + + override def getPartition( + db: String, + table: String, + spec: Map[String, String]): TablePartition = synchronized { + assertPartitionExists(db, table, spec) + catalog(db).tables(table).partitions(spec) + } + + override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).partitions.values.toSeq } // -------------------------------------------------------------------------- @@ -205,11 +276,12 @@ class InMemoryCatalog extends Catalog { // -------------------------------------------------------------------------- override def createFunction( - db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + db: String, + func: Function, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) - if (existsFunction(db, func.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Function $func already exists in $db database") } } else { @@ -222,14 +294,16 @@ class InMemoryCatalog extends Catalog { catalog(db).functions.remove(funcName) } - override def alterFunction(db: String, funcName: String, funcDefinition: Function) - : Unit = synchronized { + override def alterFunction( + db: String, + funcName: String, + funcDefinition: Function): Unit = synchronized { assertFunctionExists(db, funcName) if (funcName != funcDefinition.name) { // Also a rename; remove the old one and add the new one back catalog(db).functions.remove(funcName) } - catalog(db).functions.put(funcName, funcDefinition) + catalog(db).functions.put(funcDefinition.name, funcDefinition) } override def getFunction(db: String, funcName: String): Function = synchronized { @@ -239,7 +313,6 @@ class InMemoryCatalog extends Catalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { assertDbExists(db) - val regex = pattern.replaceAll("\\*", ".*").r filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index a6caf91f3304b..b4d7dd2f4e31c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -29,17 +29,15 @@ import org.apache.spark.sql.AnalysisException * Implementations should throw [[AnalysisException]] when table or database don't exist. 
*/ abstract class Catalog { + import Catalog._ // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit - def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit def alterDatabase(db: String, dbDefinition: Database): Unit @@ -71,11 +69,28 @@ abstract class Catalog { // Partitions // -------------------------------------------------------------------------- - // TODO: need more functions for partitioning. + def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit - def alterPartition(db: String, table: String, part: TablePartition): Unit + def dropPartitions( + db: String, + table: String, + parts: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit - def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + def alterPartition( + db: String, + table: String, + spec: PartitionSpec, + newPart: TablePartition): Unit + + def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition + + // TODO: support listing by pattern + def listPartitions(db: String, table: String): Seq[TablePartition] // -------------------------------------------------------------------------- // Functions @@ -132,11 +147,11 @@ case class Column( /** * A partition (Hive style) defined in the catalog. * - * @param values values for the partition columns + * @param spec partition spec values indexed by column name * @param storage storage format of the partition */ case class TablePartition( - values: Seq[String], + spec: Catalog.PartitionSpec, storage: StorageFormat ) @@ -176,3 +191,8 @@ case class Database( locationUri: String, properties: Map[String, String] ) + + +object Catalog { + type PartitionSpec = Map[String, String] +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index ab9d5ac8a20eb..0d8434323fcbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -27,6 +27,11 @@ import org.apache.spark.sql.AnalysisException * Implementations of the [[Catalog]] interface can create test suites by extending this. 
*/ abstract class CatalogTestCases extends SparkFunSuite { + private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map.empty[String, String]) + private val part1 = TablePartition(Map[String, String]("a" -> "1"), storageFormat) + private val part2 = TablePartition(Map[String, String]("b" -> "2"), storageFormat) + private val part3 = TablePartition(Map[String, String]("c" -> "3"), storageFormat) + private val funcClass = "org.apache.spark.myFunc" protected def newEmptyCatalog(): Catalog @@ -41,16 +46,16 @@ abstract class CatalogTestCases extends SparkFunSuite { */ private def newBasicCatalog(): Catalog = { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb("db1"), ifNotExists = false) - catalog.createDatabase(newDb("db2"), ifNotExists = false) - + catalog.createDatabase(newDb("db1"), ignoreIfExists = false) + catalog.createDatabase(newDb("db2"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) catalog } - private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + private def newFunc(): Function = Function("funcname", funcClass) private def newDb(name: String = "default"): Database = Database(name, name + " description", "uri", Map.empty) @@ -59,7 +64,7 @@ abstract class CatalogTestCases extends SparkFunSuite { Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, None, None) - private def newFunc(name: String): Function = Function(name, "class.name") + private def newFunc(name: String): Function = Function(name, funcClass) // -------------------------------------------------------------------------- // Databases @@ -67,10 +72,10 @@ abstract class CatalogTestCases extends SparkFunSuite { test("basic create, drop and list databases") { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb(), ifNotExists = false) + catalog.createDatabase(newDb(), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default")) - catalog.createDatabase(newDb("default2"), ifNotExists = false) + catalog.createDatabase(newDb("default2"), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default", "default2")) } @@ -253,11 +258,194 @@ abstract class CatalogTestCases extends SparkFunSuite { // Partitions // -------------------------------------------------------------------------- - // TODO: Add tests cases for partitions + test("basic create and list partitions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable("mydb", newTable("mytbl"), ignoreIfExists = false) + catalog.createPartitions("mydb", "mytbl", Seq(part1, part2), ignoreIfExists = false) + assert(catalog.listPartitions("mydb", "mytbl").toSet == Set(part1, part2)) + } + + test("create partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createPartitions("db2", "does_not_exist", Seq(), ignoreIfExists = false) + } + } + + test("create partitions that already exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("db2", "tbl2", Seq(part1), 
ignoreIfExists = false) + } + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) + } + + test("drop partitions") { + val catalog = newBasicCatalog() + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part2)) + val catalog2 = newBasicCatalog() + assert(catalog2.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + assert(catalog2.listPartitions("db2", "tbl2").isEmpty) + } + + test("drop partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + } + } + + test("drop partitions that do not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + } + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + } + + test("get partition") { + val catalog = newBasicCatalog() + assert(catalog.getPartition("db2", "tbl2", part1.spec) == part1) + assert(catalog.getPartition("db2", "tbl2", part2.spec) == part2) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl1", part3.spec) + } + } + + test("get partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getPartition("does_not_exist", "tbl1", part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition("db2", "does_not_exist", part1.spec) + } + } + + test("alter partitions") { + val catalog = newBasicCatalog() + val partSameSpec = part1.copy(storage = storageFormat.copy(serde = "myserde")) + val partNewSpec = part1.copy(spec = Map("x" -> "10")) + // alter but keep spec the same + catalog.alterPartition("db2", "tbl2", part1.spec, partSameSpec) + assert(catalog.getPartition("db2", "tbl2", part1.spec) == partSameSpec) + // alter and change spec + catalog.alterPartition("db2", "tbl2", part1.spec, partNewSpec) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl2", part1.spec) + } + assert(catalog.getPartition("db2", "tbl2", partNewSpec.spec) == partNewSpec) + } + + test("alter partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterPartition("does_not_exist", "tbl1", part1.spec, part1) + } + intercept[AnalysisException] { + catalog.alterPartition("db2", "does_not_exist", part1.spec, part1) + } + } // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - // TODO: Add tests cases for functions + test("basic create and list functions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction("mydb", newFunc("myfunc"), ignoreIfExists = false) + assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + } + + test("create function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("does_not_exist", newFunc(), ignoreIfExists = false) + 
} + } + + test("create function that already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + } + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = true) + } + + test("drop function") { + val catalog = newBasicCatalog() + assert(catalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction("db2", "func1") + assert(catalog.listFunctions("db2", "*").isEmpty) + } + + test("drop function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("does_not_exist", "something") + } + } + + test("drop function that does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("db2", "does_not_exist") + } + } + + test("get function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1") == newFunc("func1")) + intercept[AnalysisException] { + catalog.getFunction("db2", "does_not_exist") + } + } + + test("get function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getFunction("does_not_exist", "func1") + } + } + + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + // alter func but keep name + catalog.alterFunction("db2", "func1", newFunc("func1").copy(className = "muhaha")) + assert(catalog.getFunction("db2", "func1").className == "muhaha") + // alter func and change name + catalog.alterFunction("db2", "func1", newFunc("funcky")) + intercept[AnalysisException] { + catalog.getFunction("db2", "func1") + } + assert(catalog.getFunction("db2", "funcky").className == funcClass) + } + + test("alter function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterFunction("does_not_exist", "func1", newFunc()) + } + } + + test("list functions") { + val catalog = newBasicCatalog() + catalog.createFunction("db2", newFunc("func2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("not_me"), ignoreIfExists = false) + assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me")) + assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) + } + } From 0f81318ae217346c20894572795e1a9cee2ebc8f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 3 Feb 2016 21:05:53 -0800 Subject: [PATCH 076/107] [SPARK-12828][SQL] add natural join support Jira: https://issues.apache.org/jira/browse/SPARK-12828 Author: Daoyuan Wang Closes #10762 from adrian-wang/naturaljoin. 
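For readers new to the feature, a minimal usage sketch of what this patch enables (assuming an already-created `SQLContext` named `sqlContext`; the table and column names are made up for illustration):

    // Two tables sharing only the column name "k".
    val df1 = sqlContext.createDataFrame(Seq(("one", 1), ("two", 2))).toDF("k", "v1")
    val df2 = sqlContext.createDataFrame(Seq(("one", 10), ("three", 30))).toDF("k", "v2")
    df1.registerTempTable("nt1")
    df2.registerTempTable("nt2")

    // NATURAL JOIN implicitly equates the columns with matching names (here just "k")
    // and keeps a single copy of each join column in the output.
    sqlContext.sql("SELECT * FROM nt1 NATURAL JOIN nt2").show()
    // Output columns: k, v1, v2; only the row with k = "one" survives the inner join.
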
--- .../sql/catalyst/parser/FromClauseParser.g | 23 +++-- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 2 + .../sql/catalyst/parser/SparkSqlParser.g | 4 + .../spark/sql/catalyst/CatalystQl.scala | 4 + .../sql/catalyst/analysis/Analyzer.scala | 43 +++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../spark/sql/catalyst/plans/joinTypes.scala | 4 + .../plans/logical/basicOperators.scala | 10 ++- .../analysis/ResolveNaturalJoinSuite.scala | 90 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 24 +++++ 11 files changed, 198 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 6d76afcd4ac07..e83f8a7cd1b5c 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -117,15 +117,20 @@ joinToken @init { gParent.pushMsg("join type specifier", state); } @after { gParent.popMsg(state); } : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN + | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN + | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN + | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN ; lateralView diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 1d07a27353dcb..fd1ad59207e31 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -335,6 +335,8 @@ KW_CACHE: 'CACHE'; KW_UNCACHE: 'UNCACHE'; KW_DFS: 'DFS'; +KW_NATURAL: 'NATURAL'; + // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. 
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 6591f6b0f56ce..9935678ca2ca2 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN; TOK_FULLOUTERJOIN; TOK_UNIQUEJOIN; TOK_CROSSJOIN; +TOK_NATURALJOIN; +TOK_NATURALLEFTOUTERJOIN; +TOK_NATURALRIGHTOUTERJOIN; +TOK_NATURALFULLOUTERJOIN; TOK_LOAD; TOK_EXPORT; TOK_IMPORT; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 7ce2407913ade..a42360d5629f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -520,6 +520,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTSEMIJOIN" => LeftSemi case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) + case "TOK_NATURALJOIN" => NaturalJoin(Inner) + case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) + case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) + case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) } Join(nodeToRelation(relation1), nodeToRelation(relation2), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a983dc1cdfebe..b30ed5928fd56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -81,6 +82,7 @@ class Analyzer( ResolveAliases :: ResolveWindowOrder :: ResolveWindowFrame :: + ResolveNaturalJoin :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -1230,6 +1232,47 @@ class Analyzer( } } } + + /** + * Removes natural joins by calculating output columns based on output from two sides, + * Then apply a Project on a normal Join to eliminate natural join. + */ + object ResolveNaturalJoin extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Should not skip unresolved nodes because natural join is always unresolved. 
+ case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + // find common column names from both sides, should be treated like usingColumns + val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) + val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) + val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions + val newCondition = (condition ++ joinPairs.map { + case (l, r) => EqualTo(l, r) + }).reduceLeftOption(And) + // columns not in joinPairs + val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) + val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + // we should only keep unique columns(depends on joinType) for joinCols + val projectList = joinType match { + case LeftOuter => + leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case RightOuter => + rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + // in full outer join, joinCols should be non-null if there is. + val joinedCols = joinPairs.map { + case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() + } + joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + rUniqueOutput.map(_.withNullability(true)) + case _ => + rightKeys ++ lUniqueOutput ++ rUniqueOutput + } + // use Project to trim unnecessary fields + Project(projectList, Join(left, right, joinType, newCondition)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f156b5d10acc2..4ecee75048248 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -905,6 +905,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (rightFilterConditions ++ commonFilterCondition). 
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case FullOuter => f // DO Nothing for Full Outer Join + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } // push down the join filter into sub query scanning if applicable @@ -939,6 +940,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index a5f6764aef7ce..b10f1e63a73e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -60,3 +60,7 @@ case object FullOuter extends JoinType { case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } + +case class NaturalJoin(tpe: JoinType) extends JoinType { + override def sql: String = "NATURAL " + tpe.sql +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8150ff8434762..03a79520cbd3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,12 +250,20 @@ case class Join( def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { + // NaturalJoin should be ready for resolution only if everything else is resolved here + lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && duplicateResolved && condition.forall(_.dataType == BooleanType) } + + // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need + // to eliminate natural before we mark it resolved. + override lazy val resolved: Boolean = joinType match { + case NaturalJoin(_) => false + case _ => resolvedExceptNatural + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala new file mode 100644 index 0000000000000..a6554fbc414b5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +class ResolveNaturalJoinSuite extends AnalysisTest { + lazy val a = 'a.string + lazy val b = 'b.string + lazy val c = 'c.string + lazy val aNotNull = a.notNull + lazy val bNotNull = b.notNull + lazy val cNotNull = c.notNull + lazy val r1 = LocalRelation(a, b) + lazy val r2 = LocalRelation(a, c) + lazy val r3 = LocalRelation(aNotNull, bNotNull) + lazy val r4 = LocalRelation(bNotNull, cNotNull) + + test("natural inner join") { + val plan = r1.join(r2, NaturalJoin(Inner), None) + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural left join") { + val plan = r1.join(r2, NaturalJoin(LeftOuter), None) + val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural right join") { + val plan = r1.join(r2, NaturalJoin(RightOuter), None) + val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural full outer join") { + val plan = r1.join(r2, NaturalJoin(FullOuter), None) + val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select( + Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(plan, expected) + } + + test("natural inner join with no nullability") { + val plan = r3.join(r4, NaturalJoin(Inner), None) + val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural left join with no nullability") { + val plan = r3.join(r4, NaturalJoin(LeftOuter), None) + val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, c) + checkAnalysis(plan, expected) + } + + test("natural right join with no nullability") { + val plan = r3.join(r4, NaturalJoin(RightOuter), None) + val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, a, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural full outer join with no nullability") { + val plan = r3.join(r4, NaturalJoin(FullOuter), None) + val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( + Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) + checkAnalysis(plan, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 84203bbfef66a..f15b926bd27cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -474,6 +474,7 @@ class DataFrame private[sql]( val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } + case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.") } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 79bfd4b44b70a..8ef7b61314a56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2075,4 +2075,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } + + test("natural join") { + val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") + val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") + withTempTable("nt1", "nt2") { + df1.registerTempTable("nt1") + df2.registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } + } } From c2c956bcd1a75fd01868ee9ad2939a6d3de52bc2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 3 Feb 2016 21:19:44 -0800 Subject: [PATCH 077/107] [ML][DOC] fix wrong api link in ml onevsrest minor fix for api link in ml onevsrest Author: Yuhao Yang Closes #11068 from hhbyyh/onevsrestDoc. --- docs/ml-classification-regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 8ffc997b4bf5a..9569a06472cbf 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -289,7 +289,7 @@ The example below demonstrates how to load the
        -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details. {% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
        From d39087147ff1052b623cdba69ffbde28b266745f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Feb 2016 23:17:51 -0800 Subject: [PATCH 078/107] [SPARK-13113] [CORE] Remove unnecessary bit operation when decoding page number JIRA: https://issues.apache.org/jira/browse/SPARK-13113 As we shift bits right, looks like the bitwise AND operation is unnecessary. Author: Liang-Chi Hsieh Closes #11002 from viirya/improve-decodepagenumber. --- .../main/java/org/apache/spark/memory/TaskMemoryManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d31eb449eb82e..d2a88864f7ac9 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -312,7 +312,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) @VisibleForTesting public static int decodePageNumber(long pagePlusOffsetAddress) { - return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); } private static long decodeOffset(long pagePlusOffsetAddress) { From dee801adb78d6abd0abbf76b4dfa71aa296b4f0b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 23:43:48 -0800 Subject: [PATCH 079/107] [SPARK-12828][SQL] Natural join follow-up This is a small addendum to #10762 to make the code more robust again future changes. Author: Reynold Xin Closes #11070 from rxin/SPARK-12828-natural-join. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../spark/sql/catalyst/plans/joinTypes.scala | 2 ++ .../analysis/ResolveNaturalJoinSuite.scala | 6 +++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b30ed5928fd56..b59eb12419c45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1239,21 +1239,23 @@ class Analyzer( */ object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Should not skip unresolved nodes because natural join is always unresolved. 
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => - // find common column names from both sides, should be treated like usingColumns + // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) - }).reduceLeftOption(And) + }).reduceOption(And) + // columns not in joinPairs val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) - // we should only keep unique columns(depends on joinType) for joinCols + + // the output list looks like: join keys, columns from left, columns from right val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) @@ -1261,13 +1263,14 @@ class Analyzer( rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => // in full outer join, joinCols should be non-null if there is. - val joinedCols = joinPairs.map { - case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() - } - joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } + joinedCols ++ + lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) - case _ => + case Inner => rightKeys ++ lUniqueOutput ++ rUniqueOutput + case _ => + sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields Project(projectList, Join(left, right, joinType, newCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index b10f1e63a73e2..27a75326eba07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -62,5 +62,7 @@ case object LeftSemi extends JoinType { } case class NaturalJoin(tpe: JoinType) extends JoinType { + require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), + "Unsupported natural join type " + tpe) override def sql: String = "NATURAL " + tpe.sql } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index a6554fbc414b5..fcf4ac1967a53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest { lazy val aNotNull = a.notNull lazy val bNotNull = b.notNull lazy val cNotNull = c.notNull - lazy val r1 = LocalRelation(a, b) - lazy val r2 = LocalRelation(a, c) + lazy val r1 = LocalRelation(b, a) + lazy val r2 = LocalRelation(c, a) lazy val r3 = LocalRelation(aNotNull, bNotNull) - lazy val r4 = LocalRelation(bNotNull, cNotNull) + lazy val r4 = LocalRelation(cNotNull, bNotNull) test("natural inner join") { val plan = r1.join(r2, 
NaturalJoin(Inner), None) From 2eaeafe8a2aa31be9b230b8d53d3baccd32535b1 Mon Sep 17 00:00:00 2001 From: Charles Allen Date: Thu, 4 Feb 2016 10:27:25 -0800 Subject: [PATCH 080/107] [SPARK-12330][MESOS] Fix mesos coarse mode cleanup In the current implementation the mesos coarse scheduler does not wait for the mesos tasks to complete before ending the driver. This causes a race where the task has to finish cleaning up before the mesos driver terminates it with a SIGINT (and SIGKILL after 3 seconds if the SIGINT doesn't work). This PR causes the mesos coarse scheduler to wait for the mesos tasks to finish (with a timeout defined by `spark.mesos.coarse.shutdown.ms`) This PR also fixes a regression caused by [SPARK-10987] whereby submitting a shutdown causes a race between the local shutdown procedure and the notification of the scheduler driver disconnection. If the scheduler driver disconnection wins the race, the coarse executor incorrectly exits with status 1 (instead of the proper status 0) With this patch the mesos coarse scheduler terminates properly, the executors clean up, and the tasks are reported as `FINISHED` in the Mesos console (as opposed to `KILLED` in < 1.6 or `FAILED` in 1.6 and later) Author: Charles Allen Closes #10319 from drcrallen/SPARK-12330. --- .../CoarseGrainedExecutorBackend.scala | 8 +++- .../mesos/CoarseMesosSchedulerBackend.scala | 39 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 136cf4a84d387..3b5cb18da1b26 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} @@ -42,6 +43,7 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -102,19 +104,23 @@ private[spark] class CoarseGrainedExecutorBackend( } case StopExecutor => + stopping.set(true) logInfo("Driver commanded a shutdown") // Cannot shutdown here because an ack may need to be sent back to the caller. So send // a message to self to actually do the shutdown. self.send(Shutdown) case Shutdown => + stopping.set(true) executor.stop() stop() rpcEnv.shutdown() } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (driver.exists(_.address == remoteAddress)) { + if (stopping.get()) { + logInfo(s"Driver from $remoteAddress disconnected during shutdown") + } else if (driver.exists(_.address == remoteAddress)) { logError(s"Driver $remoteAddress disassociated! 
Shutting down.") System.exit(1) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2f095b86c69ef..722293bb7a53b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -19,11 +19,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{Collections, List => JList} +import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.base.Stopwatch import com.google.common.collect.HashBiMap import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -60,6 +62,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdown.ms", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdown.ms must be >= 0") + + // Synchronization protected by stateLock + private[this] var stopCalled: Boolean = false + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -245,6 +253,13 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { + if (stopCalled) { + logDebug("Ignoring offers during shutdown") + // Driver should simply return a stopped status on race + // condition between this.stop() and completing here + offers.asScala.map(_.getId).foreach(d.declineOffer) + return + } val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -364,7 +379,29 @@ private[spark] class CoarseMesosSchedulerBackend( } override def stop() { - super.stop() + // Make sure we're not launching tasks during shutdown + stateLock.synchronized { + if (stopCalled) { + logWarning("Stop called multiple times, ignoring") + return + } + stopCalled = true + super.stop() + } + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. + // See SPARK-12330 + val stopwatch = new Stopwatch() + stopwatch.start() + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + while (slaveIdsWithExecutors.nonEmpty && + stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { + Thread.sleep(100) + } + if (slaveIdsWithExecutors.nonEmpty) { + logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors " + + s"to terminate within $shutdownTimeoutMS ms. 
This may leave temporary files " + + "on the mesos nodes.") + } if (mesosDriver != null) { mesosDriver.stop() } From 62a7c28388539e6fc7d16ee3009f2cf79d8635bd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 4 Feb 2016 10:29:38 -0800 Subject: [PATCH 081/107] [SPARK-13164][CORE] Replace deprecated synchronized buffer in core Building with scala 2.11 results in the warning trait SynchronizedBuffer in package mutable is deprecated: Synchronization via traits is deprecated as it is inherently unreliable. Consider java.util.concurrent.ConcurrentLinkedQueue as an alternative. Investigation shows we are already using ConcurrentLinkedQueue in other locations so switch our uses of SynchronizedBuffer to ConcurrentLinkedQueue. Author: Holden Karau Closes #11059 from holdenk/SPARK-13164-replace-deprecated-synchronized-buffer-in-core. --- .../org/apache/spark/ContextCleaner.scala | 26 +++++++++---------- .../spark/deploy/client/AppClientSuite.scala | 20 +++++++------- .../org/apache/spark/rpc/RpcEnvSuite.scala | 23 ++++++++-------- .../apache/spark/util/EventLoopSuite.scala | 10 +++---- 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5a42299a0bf83..17014e4954f90 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,9 +18,9 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} @@ -57,13 +57,11 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] + private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() private val referenceQueue = new ReferenceQueue[AnyRef] - private val listeners = new ArrayBuffer[CleanerListener] - with SynchronizedBuffer[CleanerListener] + private val listeners = new ConcurrentLinkedQueue[CleanerListener]() private val cleaningThread = new Thread() { override def run() { keepCleaning() }} @@ -111,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { - listeners += listener + listeners.add(listener) } /** Start the cleaner. */ @@ -166,7 +164,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) } /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ @@ -179,7 +177,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { synchronized { reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get + referenceBuffer.remove(reference.get) task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) @@ -206,7 +204,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) - listeners.foreach(_.rddCleaned(rddId)) + listeners.asScala.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { case e: Exception => logError("Error cleaning RDD " + rddId, e) @@ -219,7 +217,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) blockManagerMaster.removeShuffle(shuffleId, blocking) - listeners.foreach(_.shuffleCleaned(shuffleId)) + listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { case e: Exception => logError("Error cleaning shuffle " + shuffleId, e) @@ -231,7 +229,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) - listeners.foreach(_.broadcastCleaned(broadcastId)) + listeners.asScala.foreach(_.broadcastCleaned(broadcastId)) logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) @@ -243,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) - listeners.foreach(_.accumCleaned(accId)) + listeners.asScala.foreach(_.accumCleaned(accId)) logInfo("Cleaned accumulator " + accId) } catch { case e: Exception => logError("Error cleaning accumulator " + accId, e) @@ -258,7 +256,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) - listeners.foreach(_.checkpointCleaned(rddId)) + listeners.asScala.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index eb794b6739d5e..658779360b7a5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.deploy.client -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll @@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd /** Application Listener to collect events */ private class AppClientCollector extends AppClientListener with Logging { - val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val connectedIdList = new ConcurrentLinkedQueue[String]() @volatile var disconnectedCount: Int = 0 - val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execAddedList = new ArrayBuffer[String] 
with SynchronizedBuffer[String] - val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val deadReasonList = new ConcurrentLinkedQueue[String]() + val execAddedList = new ConcurrentLinkedQueue[String]() + val execRemovedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { - connectedIdList += id + connectedIdList.add(id) } def disconnected(): Unit = { @@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } def dead(reason: String): Unit = { - deadReasonList += reason + deadReasonList.add(reason) } def executorAdded( @@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd hostPort: String, cores: Int, memory: Int): Unit = { - execAddedList += id + execAddedList.add(id) } def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { - execRemovedList += id + execRemovedList.add(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6f4eda8b47dde..22048003882dd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. - * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. 
*/ private def setupNetworkEndpoint( _env: RpcEnv, - name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => - case m => events += "receive" -> m + case m => events.add("receive" -> m) } override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress + events.add("onConnected" -> remoteAddress) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + events.add("onDisconnected" -> remoteAddress) } override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress + events.add("onNetworkError" -> remoteAddress) } }) @@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) } clientEnv.shutdown() @@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) - assert(events.map(_._1).contains("onDisconnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) } } finally { clientEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index b207d497f33c2..6f7dddd4f760a 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { - val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { override def onReceive(event: Int): Unit = { - buffer += event + buffer.add(event) } override def onError(e: Throwable): Unit = {} @@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts { eventLoop.start() (1 to 100).foreach(eventLoop.post) eventually(timeout(5 seconds), interval(5 millis)) { - assert((1 to 100) === buffer.toSeq) + assert((1 to 100) === buffer.asScala.toSeq) } eventLoop.stop() } From 4120bcbaffe92da40486b469334119ed12199f4f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:32:16 -0800 Subject: [PATCH 082/107] [SPARK-13162] Standalone mode does not respect initial executors Currently the Master would always 
set an application's initial executor limit to infinity. If the user specified `spark.dynamicAllocation.initialExecutors`, the config would not take effect. This is similar to #11047 but for standalone mode. Author: Andrew Or Closes #11054 from andrewor14/standalone-da-initial. --- .../spark/ExecutorAllocationManager.scala | 2 ++ .../spark/deploy/ApplicationDescription.scala | 3 +++ .../spark/deploy/master/ApplicationInfo.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++++++++++++---- .../StandaloneDynamicAllocationSuite.scala | 17 ++++++++++++++++- 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 3431fc13dcb4e..db143d7341ce4 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -231,6 +231,8 @@ private[spark] class ExecutorAllocationManager( } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 78bbd5c03f4a6..c5c5c60923f4e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -29,6 +29,9 @@ private[spark] case class ApplicationDescription( // short name of compression codec used when writing event logs, if any (e.g. lzf) eventLogCodec: Option[String] = None, coresPerExecutor: Option[Int] = None, + // number of executors this application wants to start with, + // only used if dynamic allocation is enabled + initialExecutorLimit: Option[Int] = None, user: String = System.getProperty("user.name", "")) { override def toString: String = "ApplicationDescription(" + name + ")" diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 7e2cf956c7253..4ffb5283e99a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -65,7 +65,7 @@ private[spark] class ApplicationInfo( appSource = new ApplicationSource(this) nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] - executorLimit = Integer.MAX_VALUE + executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE) appUIUrlAtHistoryServer = None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 16f33163789ab..d209645610c12 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} 
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -89,8 +89,16 @@ private[spark] class SparkDeploySchedulerBackend( args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, - command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) + // If we're using dynamic allocation, set our initial executor limit to 0 for now. + // ExecutorAllocationManager will send the real initial limit to the Master later. + val initialExecutorLimit = + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index fdada0777f9a9..b7ff5c9e8c0d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -447,7 +447,23 @@ class StandaloneDynamicAllocationSuite apps = getApplications() // kill executor successfully assert(apps.head.executors.size === 1) + } + test("initial executor limit") { + val initialExecutorLimit = 1 + val myConf = appConf + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) + sc = new SparkContext(myConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === initialExecutorLimit) + assert(apps.head.getExecutorLimit === initialExecutorLimit) + } } // =============================== @@ -540,7 +556,6 @@ class StandaloneDynamicAllocationSuite val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted missingExecutors.foreach { id => // Fake an executor registration so the driver knows about us - val port = System.currentTimeMillis % 65536 val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) From 15205da817b24ef0e349ec24d84034dc30b501f8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:34:43 -0800 Subject: [PATCH 083/107] [SPARK-13053][TEST] Unignore tests in InternalAccumulatorSuite These were ignored because they are incorrectly written; they don't actually trigger stage retries, which is what the tests are testing. These tests are now rewritten to induce stage retries through fetch failures. Note: there were 2 tests before and now there's only 1. What happened? 
It turns out that the case where we only resubmit a subset of of the original missing partitions is very difficult to simulate in tests without potentially introducing flakiness. This is because the `DAGScheduler` removes all map outputs associated with a given executor when this happens, and we will need multiple executors to trigger this case, and sometimes the scheduler still removes map outputs from all executors. Author: Andrew Or Closes #10969 from andrewor14/unignore-accum-test. --- .../org/apache/spark/AccumulatorSuite.scala | 52 +++++-- .../spark/InternalAccumulatorSuite.scala | 128 +++++++++--------- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index b8f2b96d7088d..e0fdd45973858 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -323,35 +323,60 @@ private[spark] object AccumulatorSuite { * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. */ private class SaveInfoListener extends SparkListener { - private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] - private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] - private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + type StageId = Int + type StageAttemptId = Int - // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated + private val completedStageInfos = new ArrayBuffer[StageInfo] + private val completedTaskInfos = + new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]] + + // Callback to call when a job completes. Parameter is job ID. @GuardedBy("this") + private var jobCompletionCallback: () => Unit = null + private var calledJobCompletionCallback: Boolean = false private var exception: Throwable = null def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq - def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq + def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = + completedTaskInfos.get((stageId, stageAttemptId)).getOrElse(Seq.empty[TaskInfo]) - /** Register a callback to be called on job end. */ - def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { - jobCompletionCallback = callback + /** + * If `jobCompletionCallback` is set, block until the next call has finished. + * If the callback failed with an exception, throw it. + */ + def awaitNextJobCompletion(): Unit = synchronized { + if (jobCompletionCallback != null) { + while (!calledJobCompletionCallback) { + wait() + } + calledJobCompletionCallback = false + if (exception != null) { + exception = null + throw exception + } + } } - /** Throw a stored exception, if any. */ - def maybeThrowException(): Unit = synchronized { - if (exception != null) { throw exception } + /** + * Register a callback to be called on job end. + * A call to this should be followed by [[awaitNextJobCompletion]]. 
+ */ + def registerJobCompletionCallback(callback: () => Unit): Unit = { + jobCompletionCallback = callback } override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { if (jobCompletionCallback != null) { try { - jobCompletionCallback(jobEnd.jobId) + jobCompletionCallback() } catch { // Store any exception thrown here so we can throw them later in the main thread. // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. case NonFatal(e) => exception = e + } finally { + calledJobCompletionCallback = true + notify() } } } @@ -361,7 +386,8 @@ private class SaveInfoListener extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - completedTaskInfos += taskEnd.taskInfo + completedTaskInfos.getOrElseUpdate( + (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 630b46f828df7..44a16e26f4935 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockStatus} @@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => val stageInfos = listener.getCompletedStageInfos val taskInfos = listener.getCompletedTaskInfos assert(stageInfos.size === 1) @@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators in multiple stages") { @@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => // We ran 3 stages, and the accumulator values should be distinct val stageInfos = listener.getCompletedStageInfos assert(stageInfos.size === 3) @@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { rdd.count() } - // TODO: these two tests are incorrect; they don't actually trigger stage retries. - ignore("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks - } + test("internal accumulators in resubmitted stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + + // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with + // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt. + // This should retry both stages in the scheduler. Note that we only want to fail the + // first stage attempt because we want the stage to eventually succeed. 
+ val x = sc.parallelize(1 to 100, numPartitions) + .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter } + .groupBy(identity) + val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId + val rdd = x.mapPartitionsWithIndex { case (i, iter) => + // Fail the first stage attempt. Here we use the task attempt ID to determine this. + // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt + // ID that's < 2 * numPartitions belongs to the first attempt of this stage. + val taskContext = TaskContext.get() + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + if (isFirstStageAttempt) { + throw new FetchFailedException( + SparkEnv.get.blockManager.blockManagerId, + sid, + taskContext.partitionId(), + taskContext.partitionId(), + "simulated fetch failure") + } else { + iter + } + } - ignore("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried + val mapStageId = stageInfos.head.stageId + val mapStageInfo1stAttempt = stageInfos.head + val mapStageInfo2ndAttempt = { + stageInfos.tail.find(_.stageId == mapStageId).getOrElse { + fail("expected two attempts of the same shuffle map stage.") + } + } + val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values) + val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values) + // Both map stages should have succeeded, since the fetch failure happened in the + // result stage, not the map stage. This means we should get the accumulator updates + // from all partitions. + assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions) + assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions) + // Because this test resubmitted the map stage with all missing partitions, we should have + // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is + // the case by comparing the accumulator IDs between the two attempts. + // Note: it would be good to also test the case where the map stage is resubmitted where + // only a subset of the original partitions are missing. However, this scenario is very + // difficult to construct without potentially introducing flakiness. + assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id) + } + rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators are registered for cleanups") { @@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } } - /** - * Test whether internal accumulators are merged properly if some tasks fail. - * TODO: make this actually retry the stage. 
- */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findTestAccum(stageInfos.head.accumulables.values) - // If all partitions failed, then we would resubmit the whole stage again and create a - // fresh set of internal accumulators. Otherwise, these internal accumulators do count - // failed values, so we must include the failed values. - val expectedAccumValue = - if (numPartitions == numFailedPartitions) { - numPartitions - } else { - numPartitions + numFailedPartitions - } - assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findTestAccum(taskInfo.accumulables) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.asInstanceOf[Long] === 1L) - assert(taskAccum.value.isDefined) - Some(taskAccum.value.get.asInstanceOf[Long]) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() - listener.maybeThrowException() - } - /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. */ From 085f510ae554e2739a38ee0bc7210c4ece902f3f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 11:07:06 -0800 Subject: [PATCH 084/107] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #7971 (requested by yhuai) Closes #8539 (requested by srowen) Closes #8746 (requested by yhuai) Closes #9288 (requested by andrewor14) Closes #9321 (requested by andrewor14) Closes #9935 (requested by JoshRosen) Closes #10442 (requested by andrewor14) Closes #10585 (requested by srowen) Closes #10785 (requested by srowen) Closes #10832 (requested by andrewor14) Closes #10941 (requested by marmbrus) Closes #11024 (requested by andrewor14) From 33212cb9a13a6012b4c19ccfc0fb3db75de304da Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 4 Feb 2016 11:08:50 -0800 Subject: [PATCH 085/107] [SPARK-13168][SQL] Collapse adjacent repartition operators Spark SQL should collapse adjacent `Repartition` operators and only keep the last one. Author: Josh Rosen Closes #11064 from JoshRosen/collapse-repartition. 
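For illustration, the rewrite can be sketched outside Catalyst with a toy plan tree; `Plan`, `Scan`, and the simplified `Repartition` below are stand-ins for this note only, not Spark's classes, and plain recursion stands in for the bottom-up traversal that `transformUp` performs in the real rule.

    // Toy logical plan, for illustration only.
    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Repartition(numPartitions: Int, shuffle: Boolean, child: Plan) extends Plan

    // Keep only the outermost of adjacent repartitions.
    def collapseRepartition(plan: Plan): Plan = plan match {
      case Repartition(n, shuffle, Repartition(_, _, child)) =>
        collapseRepartition(Repartition(n, shuffle, child))
      case other => other
    }

    // df.repartition(10).repartition(20) keeps only the repartition to 20 partitions:
    collapseRepartition(Repartition(20, true, Repartition(10, true, Scan("t"))))
    // => Repartition(20, true, Scan("t"))
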
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 16 ++++++++++++++-- ...ingSuite.scala => CollapseProjectSuite.scala} | 4 ++-- .../catalyst/optimizer/FilterPushdownSuite.scala | 2 +- .../sql/catalyst/optimizer/JoinOrderSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 +++++++++++++-- .../org/apache/spark/sql/hive/SQLBuilder.scala | 4 ++-- 6 files changed, 33 insertions(+), 10 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ProjectCollapsingSuite.scala => CollapseProjectSuite.scala} (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ecee75048248..a1ac93073916c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,7 +68,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughAggregate, ColumnPruning, // Operator combine - ProjectCollapsing, + CollapseRepartition, + CollapseProject, CombineFilters, CombineLimits, CombineUnions, @@ -322,7 +323,7 @@ object ColumnPruning extends Rule[LogicalPlan] { * Combines two adjacent [[Project]] operators into one and perform alias substitution, * merging the expressions into one single expression. */ -object ProjectCollapsing extends Rule[LogicalPlan] { +object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p @ Project(projectList1, Project(projectList2, child)) => @@ -390,6 +391,16 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Combines adjacent [[Repartition]] operators by keeping only the last one. + */ +object CollapseRepartition extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) => + Repartition(numPartitions, shuffle, child) + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given @@ -857,6 +868,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. 
+ * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala similarity index 96% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 85b6530481b03..f5fd5ca6beb15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -25,11 +25,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ProjectCollapsingSuite extends PlanTest { +class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: - Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + Batch("CollapseProject", Once, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index f9f3bd55aa578..b49ca928b6292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -42,7 +42,7 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala index 9b1e16c727647..858a0d8fde3ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -43,7 +43,7 @@ class JoinOrderSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 8fca5e2167d04..adaeb513bc1b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -223,6 +222,18 @@ class PlannerSuite extends SharedSQLContext { } } + test("collapse adjacent repartitions") { + val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) + def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length + assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + doubleRepartitioned.queryExecution.optimizedPlan match { + case r: Repartition => + assert(r.numPartitions === 5) + assert(r.shuffle === false) + } + } + // --- Unit tests of EnsureRequirements --------------------------------------------------------- // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1654594538366..fc5725d6915ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -188,7 +188,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - ProjectCollapsing, + CollapseProject, // Used to handle other auxiliary `Project`s added by analyzer (e.g. // `ResolveAggregateFunctions` rule) From c756bda477f458ba4aad7fdb2026263507e0ad9b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:04:54 -0800 Subject: [PATCH 086/107] [SPARK-12330][MESOS][HOTFIX] Rename timeout config The config already describes time and accepts a general format that is not restricted to ms. This commit renames the internal config to use a format that's consistent in Spark. 
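A minimal sketch of why the rename fits, using only the standard SparkConf API: getTimeAsMs parses general time strings, so a key suffixed with ".ms" would wrongly suggest the value must be given in milliseconds.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Any supported time suffix works, not just milliseconds.
    conf.set("spark.mesos.coarse.shutdownTimeout", "2min")
    val timeoutMs = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s")
    // timeoutMs == 120000L
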
--- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 722293bb7a53b..8ed7553258e7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -62,8 +62,8 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt - private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdown.ms", "10s") - .ensuring(_ >= 0, "spark.mesos.coarse.shutdown.ms must be >= 0") + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false From bd38dd6f75c4af0f8f32bb21a82da53fffa5e825 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:20:18 -0800 Subject: [PATCH 087/107] [SPARK-13079][SQL] InMemoryCatalog follow-ups This patch incorporates review feedback from #11069, which is already merged. Author: Andrew Or Closes #11080 from andrewor14/catalog-follow-ups. --- .../spark/sql/catalyst/catalog/interface.scala | 15 +++++++++++++++ .../sql/catalyst/catalog/CatalogTestCases.scala | 12 +++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b4d7dd2f4e31c..56aaa6bc6c2e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -39,6 +39,9 @@ abstract class Catalog { def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + /** + * Alter an existing database. This operation does not support renaming. + */ def alterDatabase(db: String, dbDefinition: Database): Unit def getDatabase(db: String): Database @@ -57,6 +60,9 @@ abstract class Catalog { def renameTable(db: String, oldName: String, newName: String): Unit + /** + * Alter an existing table. This operation does not support renaming. + */ def alterTable(db: String, table: String, tableDefinition: Table): Unit def getTable(db: String, table: String): Table @@ -81,6 +87,9 @@ abstract class Catalog { parts: Seq[PartitionSpec], ignoreIfNotExists: Boolean): Unit + /** + * Alter an existing table partition and optionally override its spec. + */ def alterPartition( db: String, table: String, @@ -100,6 +109,9 @@ abstract class Catalog { def dropFunction(db: String, funcName: String): Unit + /** + * Alter an existing function and optionally override its name. + */ def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit def getFunction(db: String, funcName: String): Function @@ -194,5 +206,8 @@ case class Database( object Catalog { + /** + * Specifications of a table partition indexed by column name. 
+ */ type PartitionSpec = Map[String, String] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0d8434323fcbb..45c5ceecb0eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -27,10 +27,10 @@ import org.apache.spark.sql.AnalysisException * Implementations of the [[Catalog]] interface can create test suites by extending this. */ abstract class CatalogTestCases extends SparkFunSuite { - private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map.empty[String, String]) - private val part1 = TablePartition(Map[String, String]("a" -> "1"), storageFormat) - private val part2 = TablePartition(Map[String, String]("b" -> "2"), storageFormat) - private val part3 = TablePartition(Map[String, String]("c" -> "3"), storageFormat) + private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map()) + private val part1 = TablePartition(Map("a" -> "1"), storageFormat) + private val part2 = TablePartition(Map("b" -> "2"), storageFormat) + private val part3 = TablePartition(Map("c" -> "3"), storageFormat) private val funcClass = "org.apache.spark.myFunc" protected def newEmptyCatalog(): Catalog @@ -42,6 +42,8 @@ abstract class CatalogTestCases extends SparkFunSuite { * db2 * - tbl1 * - tbl2 + * - part1 + * - part2 * - func1 */ private def newBasicCatalog(): Catalog = { @@ -50,8 +52,8 @@ abstract class CatalogTestCases extends SparkFunSuite { catalog.createDatabase(newDb("db2"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) - catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) catalog } From 8e2f296306131e6c7c2f06d6672995d3ff8ab021 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 4 Feb 2016 12:43:16 -0800 Subject: [PATCH 088/107] [SPARK-13195][STREAMING] Fix NoSuchElementException when a state is not set but timeoutThreshold is defined Check the state Existence before calling get. Author: Shixiong Zhu Closes #11081 from zsxwing/SPARK-13195. 
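The guard can be sketched with a simplified state holder; `SimpleState` and `shouldStore` below are hypothetical names used only to illustrate the condition, not Spark's `State` API.

    // Hypothetical stand-in for the wrapped per-key state.
    class SimpleState[S](value: Option[S]) {
      var updated: Boolean = false
      def exists: Boolean = value.isDefined
      def get(): S = value.get   // NoSuchElementException when no state was ever set
    }

    // Before the fix, a defined timeoutThreshold alone forced a get() on an unset state.
    // After the fix, the state is only read when it actually exists.
    def shouldStore[S](state: SimpleState[S], timeoutThreshold: Option[Long]): Boolean =
      state.updated || (state.exists && timeoutThreshold.isDefined)
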
--- .../org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 3 ++- .../apache/spark/streaming/rdd/MapWithStateRDDSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 1d2244eaf22b3..6ab1956bed900 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -57,7 +57,8 @@ private[streaming] object MapWithStateRDDRecord { val returned = mappingFunction(batchTime, key, Some(value), wrappedState) if (wrappedState.isRemoved) { newStateMap.remove(key) - } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { + } else if (wrappedState.isUpdated + || (wrappedState.exists && timeoutThresholdTime.isDefined)) { newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) } mappedData ++= returned diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index 5b13fd6ad611a..e8c814ba7184b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -190,6 +190,11 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + // If a state is not set but timeoutThreshold is defined, we should ignore this state. + // Previously it threw NoSuchElementException (SPARK-13195). 
+ assertRecordUpdate(initStates = Seq(), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil) } test("states generated by MapWithStateRDD") { From 7a4b37f02cffd6d971c07716688a7cb6cee26c8b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:47:32 -0800 Subject: [PATCH 089/107] [HOTFIX] Fix style violation caused by c756bda --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 8ed7553258e7b..b82196c86b637 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -62,8 +62,9 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt - private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") - .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") + private[this] val shutdownTimeoutMS = + conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false From 6dbfc40776514c3a5667161ebe7829f4cc9c7529 Mon Sep 17 00:00:00 2001 From: Raafat Akkad Date: Thu, 4 Feb 2016 16:09:31 -0800 Subject: [PATCH 090/107] [SPARK-13052] waitingApps metric doesn't show the number of apps currently in the WAITING state Author: Raafat Akkad Closes #10959 from RaafatAkkad/master. 
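The intent can be sketched with plain case classes (not Spark's `ApplicationInfo`): the gauge should count applications whose state is WAITING rather than report the size of an internal scheduling buffer.

    object AppState extends Enumeration { val WAITING, RUNNING, FINISHED = Value }
    case class AppInfo(id: String, state: AppState.Value)

    // What the metric is meant to report:
    def waitingApps(apps: Set[AppInfo]): Int = apps.count(_.state == AppState.WAITING)

    // waitingApps(Set(AppInfo("a", AppState.WAITING), AppInfo("b", AppState.RUNNING))) == 1
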
--- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- .../scala/org/apache/spark/deploy/master/MasterSource.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 202a1b787c21b..0f11f680b3914 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -74,7 +74,7 @@ private[deploy] class Master( val workers = new HashSet[WorkerInfo] val idToApp = new HashMap[String, ApplicationInfo] - val waitingApps = new ArrayBuffer[ApplicationInfo] + private val waitingApps = new ArrayBuffer[ApplicationInfo] val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 66a9ff38678c6..39b2647a900f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -42,6 +42,6 @@ private[spark] class MasterSource(val master: Master) extends Source { // Gauge for waiting application numbers in cluster metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] { - override def getValue: Int = master.waitingApps.size + override def getValue: Int = master.apps.filter(_.state == ApplicationState.WAITING).size }) } From e3c75c6398b1241500343ff237e9bcf78b5396f9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Feb 2016 18:37:58 -0800 Subject: [PATCH 091/107] [SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables) JIRA: https://issues.apache.org/jira/browse/SPARK-12850 This PR is to support bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, and `InSet`. Like HIVE, in this PR, the bucket pruning works when the bucketing key has one and only one column. So far, I do not find a way to verify how many buckets are actually scanned. However, I did verify it when doing the debug. Could you provide a suggestion how to do it properly? Thank you! cloud-fan yhuai rxin marmbrus BTW, we can add more cases to support complex predicate including `Or` and `And`. Please let me know if I should do it in this PR. Maybe we also need to add test cases to verify if bucket pruning works well for each data type. Author: gatorsmile Closes #10942 from gatorsmile/pruningBuckets. 
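At the usage level (not the pruning rule itself), the kind of query that benefits looks like the test code below; this sketch assumes an active SQLContext with its implicits imported, as in the suite.

    // Write a table bucketed on a single column; pruning only applies to single-column bucketing keys.
    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
    df.write
      .format("json")              // json has no other predicate pushdown, so any skipping comes from bucket pruning
      .bucketBy(8, "j")
      .saveAsTable("bucketed_table")

    // An equality filter on the bucketing column lets the scan skip every bucket that cannot contain j == 3.
    sqlContext.table("bucketed_table").filter($"j" === 3).show()
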
--- .../datasources/DataSourceStrategy.scala | 78 ++++++++- .../apache/spark/sql/sources/interfaces.scala | 15 +- .../spark/sql/sources/BucketedReadSuite.scala | 162 +++++++++++++++++- 3 files changed, 245 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index da9320ffb61c3..c24967abeb33e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.BitSet /** * A Strategy for planning scans over data sources defined using the sources API. @@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (partitionAndNormalColumnAttrs ++ projects).toSeq } + // Prune the buckets based on the pushed filters that do not contain partitioning key + // since the bucketing key is not allowed to use the columns in partitioning key + val bucketSet = getBuckets(pushedFilters, t.getBucketSpec) + val scan = buildPartitionedTableScan( l, partitionAndNormalColumnProjs, pushedFilters, + bucketSet, t.partitionSpec.partitionColumns, selectedPartitions) @@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) + // Prune the buckets based on the filters + val bucketSet = getBuckets(filters, t.getBucketSpec) pruneFilterProject( l, projects, filters, - (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil + (a, f) => + t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil case l @ LogicalRelation(baseRelation: TableScan, _, _) => execution.PhysicalRDD.createFromDataSource( @@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { logicalRelation: LogicalRelation, projections: Seq[NamedExpression], filters: Seq[Expression], + buckets: Option[BitSet], partitionColumns: StructType, partitions: Array[Partition]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] @@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. val dataRows = relation.buildInternalScan( - requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) + requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast) // Merges data values with partition values. 
mergeWithPartitionValues( @@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } } + // Get the bucket ID based on the bucketing values. + // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType)) + mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + val bucketIdGeneration = UnsafeProjection.create( + HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + + bucketIdGeneration(mutableRow).getInt(0) + } + + // Get the bucket BitSet by reading the filters that only contains bucketing keys. + // Note: When the returned BitSet is None, no pruning is possible. + // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + private def getBuckets( + filters: Seq[Expression], + bucketSpec: Option[BucketSpec]): Option[BitSet] = { + + if (bucketSpec.isEmpty || + bucketSpec.get.numBuckets == 1 || + bucketSpec.get.bucketColumnNames.length != 1) { + // None means all the buckets need to be scanned + return None + } + + // Just get the first because bucketing pruning only works when the column has one column + val bucketColumnName = bucketSpec.get.bucketColumnNames.head + val numBuckets = bucketSpec.get.numBuckets + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.clear() + + filters.foreach { + case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + val hSet = list.map(e => e.eval(EmptyRow)) + hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, null)) + case _ => + } + + logInfo { + val selected = matchedBuckets.cardinality() + val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100 + s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions." 
+ } + + // None means all the buckets need to be scanned + if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) + } + protected def prunePartitions( predicates: Seq[Expression], partitionSpec: PartitionSpec): Seq[Partition] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 299fc6efbb046..737be7dfd12f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.BitSet /** * ::DeveloperApi:: @@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql]( final private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], + bucketSet: Option[BitSet], inputPaths: Array[String], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val inputStatuses = inputPaths.flatMap { input => @@ -743,9 +745,16 @@ abstract class HadoopFsRelation private[sql]( // id from file name. Then read these files into a RDD(use one-partition empty RDD for empty // bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result. val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId => - groupedBucketFiles.get(bucketId).map { inputStatuses => - buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1) - }.getOrElse(sqlContext.emptyResult) + // If the current bucketId is not set in the bucket bitSet, skip scanning it. 
+ if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)){ + sqlContext.emptyResult + } else { + // When all the buckets need a scan (i.e., bucketSet is equal to None) + // or when the current bucket need a scan (i.e., the bit of bucketId is set to true) + groupedBucketFiles.get(bucketId).map { inputStatuses => + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1) + }.getOrElse(sqlContext.emptyResult) + } } new UnionRDD(sqlContext.sparkContext, perBucketRows) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 150d0c748631e..9ba645626fe72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -19,22 +19,28 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.Exchange -import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.execution.{Exchange, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private val nullDF = (for { + i <- 0 to 50 + s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") + } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + test("read bucketed data") { - val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") withTable("bucketed_table") { df.write .format("parquet") @@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } + // To verify if the bucket pruning works, this function checks two conditions: + // 1) Check if the pruned buckets (before filtering) are empty. + // 2) Verify the final result is the same as the expected one + private def checkPrunedAnswers( + bucketSpec: BucketSpec, + bucketValues: Seq[Integer], + filterCondition: Column, + originalDataFrame: DataFrame): Unit = { + + val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") + val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec + // Limit: bucket pruning only works when the bucket column has one and only one column + assert(bucketColumnNames.length == 1) + val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) + val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + } + + // Filter could hide the bug in bucket pruning. 
Thus, skipping all the filters + val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan + .find(_.isInstanceOf[PhysicalRDD]) + assert(rdd.isDefined) + + val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty) + } + // checking if all the pruned buckets are empty + assert(checkedResult.collect().forall(_ == true)) + + checkAnswer( + bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), + originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")) + } + + test("read partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + // Case 1: EqualTo + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + + // Case 2: EqualNullSafe + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" <=> j, + df) + + // Case 3: In + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), + df) + } + } + } + + test("read non-partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + } + } + } + + test("read partitioning bucketed tables having null in bucketing key") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + nullDF.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + // Case 1: isNull + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j".isNull, + nullDF) + + // Case 2: <=> null + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j" <=> null, + nullDF) + } + } + + test("read partitioning bucketed tables having composite filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"k" > $"j", + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"i" > j % 5, + df) + } + } + } + private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") From 352102ed0b7be8c335553d7e0389fd7ce83f5fbf Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Thu, 4 Feb 2016 22:22:41 
-0800 Subject: [PATCH 092/107] [SPARK-13208][CORE] Replace use of Pairs with Tuple2s Another trivial deprecation fix for Scala 2.11 Author: Jakob Odersky Closes #11089 from jodersky/SPARK-13208. --- .../scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 2 +- .../scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 ++-- .../execution/datasources/parquet/ParquetSchemaSuite.scala | 2 +- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 37ae007f880c2..13e18a56c8fd8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -230,7 +230,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { + def histogram(bucketCount: Int): (Array[scala.Double], Array[Long]) = { val result = srdd.histogram(bucketCount) (result._1, result._2) } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index c07f346bbafd5..bd61d04d42f05 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -103,7 +103,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope { + def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope { // Scala's built-in range has issues. 
See #SI-8782 def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { val span = max - min @@ -112,7 +112,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, - Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => + Double.PositiveInfinity)((e: Double, x: (Double, Double)) => (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index d860651d421f8..90e3d50714ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -261,7 +261,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) - testSchemaInference[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[(Int, String)]]( "struct", """ |message root { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 9632d27a2ffce..1337a25eb26a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -770,14 +770,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) - .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + .zipWithIndex.map {case ((value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() - .map(x => Pair(x.getString(0), x.getInt(1))) + .map(x => (x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(results === Array(("foo", 4))) TestHive.reset() } From 82d84ff2dd3efb3bda20b529f09a4022586fb722 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Feb 2016 22:43:44 -0800 Subject: [PATCH 093/107] [SPARK-13187][SQL] Add boolean/long/double options in DataFrameReader/Writer This patch adds option functions for boolean, long, and double types. This makes it slightly easier for Spark users to specify options without turning them into strings. Using the JSON data source as an example, before this patch: ```scala sqlContext.read.option("primitivesAsString", "true").json("/path/to/json") ``` After this patch: ```scala sqlContext.read.option("primitivesAsString", true).json("/path/to/json") ``` Author: Reynold Xin Closes #11072 from rxin/SPARK-13187.
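For reference, a slightly fuller usage sketch of the new typed overloads. This is an illustration added for this write-up, not part of the patch: it assumes a `SQLContext` named `sqlContext` is in scope, the file paths are placeholders, and apart from `primitivesAsString` and `samplingRatio` (JSON data source options) the option key used on the write side is made up.

```scala
// Reading JSON with the new Boolean and Double overloads of option().
val df = sqlContext.read
  .option("primitivesAsString", true)  // Boolean overload added by this patch
  .option("samplingRatio", 0.5)        // Double overload added by this patch
  .json("/path/to/json")

// The same overloads are added to DataFrameWriter; "someLongOpt" is a
// hypothetical key passed through to the data source as the string "123".
df.write
  .format("json")
  .option("someLongOpt", 123L)         // Long overload added by this patch
  .save("/path/to/output")
```

Under the hood each overload simply calls `value.toString`, so the behavior is identical to passing the stringified value yourself.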
--- .../apache/spark/sql/DataFrameReader.scala | 21 ++++++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 21 ++++++++++++++++ .../sql/streaming/DataStreamReaderSuite.scala | 25 +++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a58643a5ba154..962fdadf1431d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -78,6 +78,27 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { this } + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameReader = option(key, value.toString) + /** * (Scala-specific) Adds input options for the underlying data source. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 80447fefe1f2a..8060198968988 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -95,6 +95,27 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameWriter = option(key, value.toString) + /** * (Scala-specific) Adds output options for the underlying data source. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala index b36b41cac9b4f..95a17f338d04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -162,4 +162,29 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { assert(LastOptions.parameters("path") == "/test") } + test("test different data types for options") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.parameters = null + df.write + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + .stop() + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } } From 7b73f1719cff233645c7850a5dbc8ed2dc9c9a58 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 Feb 2016 13:44:34 -0800 Subject: [PATCH 094/107] [SPARK-13166][SQL] Rename DataStreamReaderWriterSuite to DataFrameReaderWriterSuite A follow up PR for #11062 because it didn't rename the test suite. Author: Shixiong Zhu Closes #11096 from zsxwing/rename. --- ...StreamReaderSuite.scala => DataFrameReaderWriterSuite.scala} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DataStreamReaderSuite.scala => DataFrameReaderWriterSuite.scala} (98%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 95a17f338d04d..36212e4395985 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -56,7 +56,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { +class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext { import testImplicits._ test("resolve default source") { From 1ed354a5362967d904e9513e5a1618676c9c67a6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Feb 2016 14:34:12 -0800 Subject: [PATCH 095/107] [SPARK-12939][SQL] migrate encoder resolution logic to Analyzer https://issues.apache.org/jira/browse/SPARK-12939 Now we will catch `ObjectOperator` in `Analyzer` and resolve the `fromRowExpression/deserializer` inside it. Also update the `MapGroups` and `CoGroup` to pass in `dataAttributes`, so that we can correctly resolve value deserializer(the `child.output` contains both groupking key and values, which may mess things up if they have same-name attribtues). End-to-end tests are added. 
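To make the fixed behavior concrete, here is a hedged sketch of the kind of query that previously tripped up deserializer resolution (adapted from the idea behind the new end-to-end tests, not copied from the patch; it assumes a `SQLContext` whose implicits are imported).

```scala
// Both the grouping key and the grouped values are of type ClassData, so the
// key's field "a" and the value's field "a" share a name; resolving the value
// deserializer against child.output could previously pick the wrong attribute.
case class ClassData(a: String, b: Int)

import sqlContext.implicits._
val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
val agged = ds.groupBy(d => ClassData(d.a, 0)).mapGroups {
  case (key, values) => key.a + values.map(_.b).sum
}
agged.collect()  // expected: Array("a3")
```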
follow-ups: * remove encoders from typed aggregate expression. * completely remove resolve/bind in `ExpressionEncoder` Author: Wenchen Fan Closes #10852 from cloud-fan/bug. --- .../sql/catalyst/analysis/Analyzer.scala | 82 +++++++++++++------ .../catalyst/encoders/ExpressionEncoder.scala | 50 +++++------ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 9 +- .../sql/catalyst/plans/logical/object.scala | 68 +++++++++++---- .../encoders/EncoderResolutionSuite.scala | 8 +- .../encoders/ExpressionEncoderSuite.scala | 5 +- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../org/apache/spark/sql/GroupedDataset.scala | 3 + .../spark/sql/execution/SparkStrategies.scala | 9 +- .../apache/spark/sql/execution/objects.scala | 31 +++---- .../org/apache/spark/sql/DatasetSuite.scala | 64 +++++++++++++-- 12 files changed, 230 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b59eb12419c45..cb228cf52b433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ @@ -457,25 +458,34 @@ class Analyzer( // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => - val newOrdering = resolveSortOrders(ordering, child, throws = false) + val newOrdering = + ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. case g @ Generate(generator, join, outer, qualifier, output, child) if child.resolved && !generator.resolved => - val newG = generator transformUp { - case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } - case UnresolvedExtractValue(child, fieldExpr) => - ExtractValue(child, fieldExpr, resolver) - } + val newG = resolveExpression(generator, child, throws = true) if (newG.fastEquals(generator)) { g } else { Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } + // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator + // should be resolved by their corresponding attributes instead of children's output. 
+ case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) => + val deserializerToAttributes = o.deserializers.map { + case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes + }.toMap + + o.transformExpressions { + case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes => + resolveDeserializer(expr, attributes) + }.getOrElse(expr) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { @@ -490,6 +500,32 @@ class Analyzer( } } + private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = { + exprs.exists { expr => + !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined + } + } + + def resolveDeserializer( + deserializer: Expression, + attributes: Seq[Attribute]): Expression = { + val unbound = deserializer transform { + case b: BoundReference => attributes(b.ordinal) + } + + resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + n.copy(outerPointer = Some(Literal.fromObject(outer))) + } + } + def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { case a: Alias => Alias(a.child, a.name)() @@ -508,23 +544,20 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { - ordering.map { order => - // Resolve SortOrder in one round. - // If throws == false or the desired attribute doesn't exist - // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. - // Else, throw exception. - try { - val newOrder = order transformUp { - case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).getOrElse(u) - case UnresolvedExtractValue(child, fieldName) if child.resolved => - ExtractValue(child, fieldName, resolver) - } - newOrder.asInstanceOf[SortOrder] - } catch { - case a: AnalysisException if !throws => order + private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = { + // Resolve expression in one round. + // If throws == false or the desired attribute doesn't exist + // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. + // Else, throw exception. + try { + expr transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) } + } catch { + case a: AnalysisException if !throws => expr } } @@ -619,7 +652,8 @@ class Analyzer( ordering: Seq[SortOrder], plan: LogicalPlan, child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, child, throws = false) + val newOrdering = + ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. 
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 64832dc114e67..58f6d0eb9e929 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -50,7 +50,7 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(typeTag[T].tpe) val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] @@ -257,12 +257,10 @@ case class ExpressionEncoder[T]( } /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. + * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce + * friendly error messages to explain why it fails to resolve if there is something wrong. */ - def resolve( - schema: Seq[Attribute], - outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + def validate(schema: Seq[Attribute]): Unit = { def fail(st: StructType, maxOrdinal: Int): Unit = { throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + "but failed as the number of fields does not line up.\n" + @@ -270,6 +268,8 @@ case class ExpressionEncoder[T]( " - Target schema: " + this.schema.simpleString) } + // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all + // `BoundReference`, make sure their ordinals are all valid. var maxOrdinal = -1 fromRowExpression.foreach { case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal @@ -279,6 +279,10 @@ case class ExpressionEncoder[T]( fail(StructType.fromAttributes(schema), maxOrdinal) } + // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of + // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. + // Note that, `BoundReference` contains the expected type, but here we need the actual type, so + // we unbound it by the given `schema` and propagate the actual type to `GetStructField`. val unbound = fromRowExpression transform { case b: BoundReference => schema(b.ordinal) } @@ -299,28 +303,24 @@ case class ExpressionEncoder[T]( fail(schema, maxOrdinal) } } + } - val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer( + fromRowExpression, schema) + + // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check + // analysis, go through optimizer, etc. 
+ val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) - val optimizedPlan = SimplifyCasts(analyzedPlan) - - // In order to construct instances of inner classes (for example those declared in a REPL cell), - // we need an instance of the outer scope. This rule substitues those outer objects into - // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` - // registry. - copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => - val outer = outerScopes.get(n.cls.getDeclaringClass.getName) - if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + - s"to the scope that this class was defined in. " + "" + - "Try moving this class out of its parent class.") - } - - n.copy(outerPointer = Some(Literal.fromObject(outer))) - }) + copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 89d40b3b2c141..d8f755a39c7ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -154,7 +154,7 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable)) + constructorFor(field) ) } CreateExternalRow(fields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a1ac93073916c..902e18081bddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,10 +119,13 @@ object SamplePushDown extends Rule[LogicalPlan] { */ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case m @ MapPartitions(_, input, _, child: ObjectOperator) - if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType => + case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => val childWithoutSerialization = child.withObjectOutput - m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization) + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 760348052739c..3f97662957b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import 
org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{ObjectType, StructType} /** * A trait for logical operators that apply user defined functions to domain objects. @@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan { /** The serializer that is used to produce the output of this operator. */ def serializer: Seq[NamedExpression] + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + /** + * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects. + * It must also provide the attributes that are available during the resolution of each + * deserializer. + */ + def deserializers: Seq[(Expression, Seq[Attribute])] + /** * The object type that is produced by the user defined function. Note that the return type here * is the same whether or not the operator is output serialized data. @@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan { def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { this } else { - withNewSerializer(outputObject) + withNewSerializer(outputObject :: Nil) } /** Returns a copy of this operator with a different serializer. */ - def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy { + def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { productIterator.map { - case c if c == serializer => newSerializer :: Nil + case c if c == serializer => newSerializer case other: AnyRef => other }.toArray } @@ -70,15 +79,16 @@ object MapPartitions { /** * A relation produced by applying `func` to each partition of the `child`. - * @param input used to extract the input to `func` from an input row. + * + * @param deserializer used to extract the input to `func` from an input row. * @param serializer use to serialize the output of `func`. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `AppendColumn` nodes. */ @@ -97,16 +107,21 @@ object AppendColumns { /** * A relation produced by applying `func` to each partition of the `child`, concatenating the * resulting columns at the end of the input row. - * @param input used to extract the input to `func` from an input row. + * + * @param deserializer used to extract the input to `func` from an input row. * @param serializer use to serialize the output of `func`. */ case class AppendColumns( func: Any => Any, - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = child.output ++ newColumns + def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) + + override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `MapGroups` nodes. 
*/ @@ -114,6 +129,7 @@ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], @@ -121,6 +137,7 @@ object MapGroups { encoderFor[T].fromRowExpression, encoderFor[U].namedExpressions, groupingAttributes, + dataAttributes, child) } } @@ -129,19 +146,22 @@ object MapGroups { * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. * Func is invoked with an object representation of the grouping key an iterator containing the * object representation of all the rows with that key. - * @param keyObject used to extract the key object for each group. - * @param input used to extract the items in the iterator from an input row. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. * @param serializer use to serialize the output of `func`. */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - input: Expression, + keyDeserializer: Expression, + valueDeserializer: Expression, serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: LogicalPlan) extends UnaryNode with ObjectOperator { - def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = + Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) } /** Factory for constructing new `CoGroup` nodes. */ @@ -150,8 +170,12 @@ object CoGroup { func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftData: Seq[Attribute], + rightData: Seq[Attribute], left: LogicalPlan, right: LogicalPlan): CoGroup = { + require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) + CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], encoderFor[Key].fromRowExpression, @@ -160,6 +184,8 @@ object CoGroup { encoderFor[Result].namedExpressions, leftGroup, rightGroup, + leftData, + rightData, left, right) } @@ -171,15 +197,21 @@ object CoGroup { */ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - leftObject: Expression, - rightObject: Expression, + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectOperator { + override def producedAttributes: AttributeSet = outputSet - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = + // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve + // the `keyDeserializer` based on either of them, here we pick the left one. 
+ Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index bc36a55ae0ea2..92a68a4dba915 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -127,7 +127,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.long, 'c.int) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct\n" + @@ -136,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct\n" + @@ -149,7 +149,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct>\n" + @@ -158,7 +158,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long)) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct>\n" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 88c558d80a79a..e00060f9b6aff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders import java.sql.{Date, Timestamp} import java.util.Arrays -import java.util.concurrent.ConcurrentMap import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag -import com.google.common.collect.MapMaker - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} @@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable { } class ExpressionEncoderSuite extends SparkFunSuite { - OuterScopes.outerScopes.put(getClass.getName, this) + OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f182270a08729..378763268acc6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,6 +74,7 @@ class Dataset[T] private[sql]( * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + unresolvedTEncoder.validate(logicalPlan.output) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = @@ -85,7 +86,7 @@ class Dataset[T] private[sql]( */ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) - private implicit def classTag = resolvedTEncoder.clsTag + private implicit def classTag = unresolvedTEncoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b3f8284364782..c0e28f2dc5bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql]( MapGroups( f, groupingAttributes, + dataAttributes, logicalPlan)) } @@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql]( f, this.groupingAttributes, other.groupingAttributes, + this.dataAttributes, + other.dataAttributes, this.logicalPlan, other.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9293e55141757..830bb011beab4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -306,11 +306,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.MapPartitions(f, in, out, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil - case logical.MapGroups(f, key, in, out, grouping, child) => - execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil - case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) => + case logical.MapGroups(f, key, in, out, grouping, data, child) => + execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil + case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) => execution.CoGroup( - f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil + f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, + planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 2acca1743cbb9..582dda8603f4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan { */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - input: Expression, + deserializer: Expression, serializer: 
Seq[NamedExpression], child: SparkPlan) extends UnaryNode with ObjectOperator { override def output: Seq[Attribute] = serializer.map(_.toAttribute) override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(input, child.output) + val getObject = generateToObject(deserializer, child.output) val outputObject = generateToRow(serializer) func(iter.map(getObject)).map(outputObject) } @@ -72,7 +72,7 @@ case class MapPartitions( */ case class AppendColumns( func: Any => Any, - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with ObjectOperator { @@ -82,7 +82,7 @@ case class AppendColumns( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(input, child.output) + val getObject = generateToObject(deserializer, child.output) val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) val outputObject = generateToRow(serializer) @@ -103,10 +103,11 @@ case class AppendColumns( */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - input: Expression, + keyDeserializer: Expression, + valueDeserializer: Expression, serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: SparkPlan) extends UnaryNode with ObjectOperator { override def output: Seq[Attribute] = serializer.map(_.toAttribute) @@ -121,8 +122,8 @@ case class MapGroups( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val getKey = generateToObject(keyObject, groupingAttributes) - val getValue = generateToObject(input, child.output) + val getKey = generateToObject(keyDeserializer, groupingAttributes) + val getValue = generateToObject(valueDeserializer, dataAttributes) val outputObject = generateToRow(serializer) grouped.flatMap { case (key, rowIter) => @@ -142,12 +143,14 @@ case class MapGroups( */ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - leftObject: Expression, - rightObject: Expression, + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: SparkPlan, right: SparkPlan) extends BinaryNode with ObjectOperator { @@ -164,9 +167,9 @@ case class CoGroup( val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val getKey = generateToObject(keyObject, leftGroup) - val getLeft = generateToObject(leftObject, left.output) - val getRight = generateToObject(rightObject, right.output) + val getKey = generateToObject(keyDeserializer, leftGroup) + val getLeft = generateToObject(leftDeserializer, leftAttr) + val getRight = generateToObject(rightDeserializer, rightAttr) val outputObject = generateToRow(serializer) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b69bb21db532b..374f4320a9239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } - test("SPARK-12404: Datatype Helper Serializablity") { val ds = sparkContext.parallelize(( - new Timestamp(0), - new Date(0), - java.math.BigDecimal.valueOf(1), - scala.math.BigDecimal(1)) :: Nil).toDS() + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() ds.collect() } @@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("verify mismatching field names fail with a good error") { val ds = Seq(ClassData("a", 1)).toDS() val e = intercept[AnalysisException] { - ds.as[ClassData2].collect() + ds.as[ClassData2] } assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) } @@ -567,6 +567,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) checkAnswer(ds1.toDF(), Row(Row(null))) } + + test("support inner class in Dataset") { + val outer = new OuterClass + OuterScopes.addOuterScope(outer) + val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS() + checkAnswer(ds.map(_.a), "1", "2") + } + + test("grouping key and grouped value has field with same name") { + val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS() + val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups { + case (key, values) => key.a + values.map(_.b).sum + } + + checkAnswer(agged, "a3") + } + + test("cogroup's left and right side has field with same name") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS() + val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) { + case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum) + } + + checkAnswer(cogrouped, "a13", "b24") + } + + test("give nice error message when the real number of fields doesn't match encoder schema") { + val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + + val message = intercept[AnalysisException] { + ds.as[(String, Int, Long)] + }.message + assert(message == + "Try to map struct to Tuple3, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:int,_3:bigint>") + + val message2 = intercept[AnalysisException] { + ds.as[Tuple1[String]] + }.message + assert(message2 == + "Try to map struct to Tuple1, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string>") + } +} + +class OuterClass extends Serializable { + case class InnerClass(a: String) } case class ClassData(a: String, b: Int) From 66e1383de2650a0f06929db8109a02e32c5eaf6b Mon Sep 17 00:00:00 2001 From: Bill Chambers Date: Fri, 5 Feb 2016 14:35:39 -0800 Subject: [PATCH 096/107] [SPARK-13214][DOCS] update dynamicAllocation documentation Author: Bill Chambers Closes #11094 from anabranch/dynamic-docs. 
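For context, a minimal configuration sketch (illustrative only, not part of the patch) showing the two settings the updated documentation describes; the application name is a placeholder.

```scala
import org.apache.spark.SparkConf

// Dynamic allocation scales the number of executors with the workload.
// As the docs note, it requires the external shuffle service to be enabled.
val conf = new SparkConf()
  .setAppName("dynamic-allocation-example")        // placeholder name
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true")
```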
--- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 93b399d819ccd..cd9dc1bcfc113 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1169,8 +1169,8 @@ Apart from these, the following properties are also available, and may be useful false Whether to use dynamic resource allocation, which scales the number of executors registered - with this application up and down based on the workload. Note that this is currently only - available on YARN mode. For more detail, see the description + with this application up and down based on the workload. + For more detail, see the description here.

        This requires spark.shuffle.service.enabled to be set. From 0bb5b73387e60ef007b415fba69a3e1e89a4f013 Mon Sep 17 00:00:00 2001 From: Luc Bourlier Date: Fri, 5 Feb 2016 14:37:42 -0800 Subject: [PATCH 097/107] [SPARK-13002][MESOS] Send initial request of executors for dyn allocation Fix for [SPARK-13002](https://issues.apache.org/jira/browse/SPARK-13002) about the initial number of executors when running with dynamic allocation on Mesos. Instead of fixing it just for the Mesos case, made the change in `ExecutorAllocationManager`. It is already driving the number of executors running on Mesos, only no the initial value. The `None` and `Some(0)` are internal details on the computation of resources to reserved, in the Mesos backend scheduler. `executorLimitOption` has to be initialized correctly, otherwise the Mesos backend scheduler will, either, create to many executors at launch, or not create any executors and not be able to recover from this state. Removed the 'special case' description in the doc. It was not totally accurate, and is not needed anymore. This doesn't fix the same problem visible with Spark standalone. There is no straightforward way to send the initial value in standalone mode. Somebody knowing this part of the yarn support should review this change. Author: Luc Bourlier Closes #11047 from skyluc/issue/initial-dyn-alloc-2. --- .../mesos/CoarseMesosSchedulerBackend.scala | 13 +++++++++--- docs/running-on-mesos.md | 21 ++++++++----------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b82196c86b637..0a2d72f4dcb4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -87,10 +87,17 @@ private[spark] class CoarseMesosSchedulerBackend( val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] /** - * The total number of executors we aim to have. Undefined when not using dynamic allocation - * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + * The total number of executors we aim to have. Undefined when not using dynamic allocation. + * Initially set to 0 when using dynamic allocation, the executor allocation manager will send + * the real initial limit later. */ - private var executorLimitOption: Option[Int] = None + private var executorLimitOption: Option[Int] = { + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + } /** * Return the current executor limit, which may be [[Int.MaxValue]] diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 0ef1ccb36e117..e1c87a8d95a32 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -246,18 +246,15 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu # Dynamic Resource Allocation with Mesos -Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics -of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down -since it is already designed to run one executor per slave with the configured amount of resources. 
However, after scaling down the number of executors the coarse grain scheduler -can scale back up to the same amount of executors when Spark signals more executors are needed. - -Users that like to utilize this feature should launch the Mesos Shuffle Service that -provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's -termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh -scripts accordingly. - -The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos -is to launch the Shuffle Service with Marathon with a unique host constraint. +Mesos supports dynamic allocation only with coarse-grain mode, which can resize the number of +executors based on statistics of the application. For general information, +see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation). + +The External Shuffle Service to use is the Mesos Shuffle Service. It provides shuffle data cleanup functionality +on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`. + +This can also be achieved through Marathon, using a unique host constraint, and the following command: `bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`. # Configuration From 875f5079290a06c12d44de657dfeabc913e4a303 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Feb 2016 15:07:43 -0800 Subject: [PATCH 098/107] [SPARK-13215] [SQL] remove fallback in codegen Since we remove the configuration for codegen, we are heavily reply on codegen (also TungstenAggregate require the generated MutableProjection to update UnsafeRow), should remove the fallback, which could make user confusing, see the discussion in SPARK-13116. Author: Davies Liu Closes #11097 from davies/remove_fallback. 
--- .../spark/sql/execution/SparkPlan.scala | 36 ++----------------- .../spark/sql/execution/aggregate/udaf.scala | 12 ++----- .../spark/sql/execution/local/LocalNode.scala | 26 ++------------ 3 files changed, 8 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b19b772409d83..3cc99d3c7b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -200,47 +200,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute], useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - try { - GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } + GeneratePredicate.generate(expression, inputSchema) } protected def newOrdering( order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to interpreted", e) - new InterpretedOrdering(order, inputSchema) - } - } + GenerateOrdering.generate(order, inputSchema) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 5a19920add717..812e696338362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedMutableProjection, MutableRow} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, ImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -361,13 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - try { - 
GenerateMutableProjection.generate(children, inputAttributes)() - } catch { - case e: Exception => - log.error("Failed to generate mutable projection, fallback to interpreted", e) - new InterpretedMutableProjection(children, inputAttributes) - } + GenerateMutableProjection.generate(children, inputAttributes)() } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index a0dfe996ccd55..8726e4878106d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.Logging import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -96,33 +94,13 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } + GenerateMutableProjection.generate(expressions, inputSchema) } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } + GeneratePredicate.generate(expression, inputSchema) } } From 6883a5120c6b3a806024dfad03b8bf7371c6c0eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 5 Feb 2016 19:00:12 -0800 Subject: [PATCH 099/107] [SPARK-13171][CORE] Replace future calls with Future Trivial search-and-replace to eliminate deprecation warnings in Scala 2.11. Also works with 2.10 Author: Jakob Odersky Closes #11085 from jodersky/SPARK-13171. 
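As a quick reference for the substitution performed in this patch, here is a small self-contained sketch of the non-deprecated API, in plain Scala with no Spark dependencies:

    import scala.concurrent.{Await, Future, Promise}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    object FutureApplySketch {
      def main(args: Array[String]): Unit = {
        // Deprecated in Scala 2.11 (and what this patch replaces):
        //   val f = future { 21 * 2 }
        // Equivalent, non-deprecated form that also compiles on 2.10:
        val f = Future { 21 * 2 }

        // Likewise, Promise[T]() replaces the deprecated promise[T]().
        val p = Promise[String]()
        p.success("done")

        println(Await.result(f, 1.second))         // 42
        println(Await.result(p.future, 1.second))  // done
      }
    }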
--- .../spark/deploy/FaultToleranceTest.scala | 8 ++++---- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/JobCancellationSuite.scala | 18 +++++++++--------- .../ShuffleBlockFetcherIteratorSuite.scala | 6 +++--- .../expressions/CodeGenerationSuite.scala | 2 +- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../thriftserver/HiveThriftServer2Suites.scala | 6 +++--- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c0ede4b7c8734..15d220d01b461 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -22,7 +22,7 @@ import java.net.URL import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{future, promise, Await} +import scala.concurrent.{Await, Future, Promise} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps @@ -249,7 +249,7 @@ private object FaultToleranceTest extends App with Logging { /** This includes Client retry logic, so it may take a while if the cluster is recovering. */ private def assertUsable() = { - val f = future { + val f = Future { try { val res = sc.parallelize(0 until 10).collect() assertTrue(res.toList == (0 until 10)) @@ -283,7 +283,7 @@ private object FaultToleranceTest extends App with Logging { numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1 } - val f = future { + val f = Future { try { while (!stateValid()) { Thread.sleep(1000) @@ -405,7 +405,7 @@ private object SparkDocker { } private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { - val ipPromise = promise[String]() + val ipPromise = Promise[String]() val outFile = File.createTempFile("fault-tolerance-test", "", Utils.createTempDir()) val outStream: FileWriter = new FileWriter(outFile) def findIpAndLog(line: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 179d3b9f20b1f..df3c286a0a66f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -394,7 +394,7 @@ private[deploy] class Worker( // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. 
val appIds = executors.values.map(_.appId).toSet - val cleanupFuture = concurrent.future { + val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { throw new IOException("ERROR: Failed to list files in " + appDirs) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index e13a442463e8d..c347ab8dc8020 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.Semaphore import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.future +import scala.concurrent.Future import org.scalatest.BeforeAndAfter import org.scalatest.Matchers @@ -103,7 +103,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val rdd1 = rdd.map(x => x) - future { + Future { taskStartedSemaphore.acquire() sc.cancelAllJobs() taskCancelledSemaphore.release(100000) @@ -126,7 +126,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled") sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() } @@ -191,7 +191,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count() } @@ -231,7 +231,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val f2 = rdd.countAsync() // Kill one of the action. - future { + Future { sem1.acquire() f1.cancel() JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) @@ -247,7 +247,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -263,7 +263,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { + Future { // Wait until some tasks were launched before we cancel the job. 
sem.acquire() f.cancel() @@ -277,7 +277,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -292,7 +292,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { + Future { sem.acquire() f.cancel() } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 828153bdbfc44..c9c2fb2691d70 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,7 +21,7 @@ import java.io.InputStream import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.future +import scala.concurrent.Future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ @@ -149,7 +149,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first two blocks, and wait till task completion before returning the 3rd one listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) @@ -211,7 +211,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first block, and then fail. 
listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0c42e2fc7c5e5..b5413fbe2bbcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -36,7 +36,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { import scala.concurrent.duration._ val futures = (1 to 20).map { _ => - future { + Future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8b275e886c46c..943ad31c0cef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -77,7 +77,7 @@ case class BroadcastHashJoin( // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { + Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. SQLExecution.withExecutionId(sparkContext, executionId) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index db8edd169dcfa..f48fc3b84864d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -76,7 +76,7 @@ case class BroadcastHashOuterJoin( // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { + Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. 
SQLExecution.withExecutionId(sparkContext, executionId) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ba3b26e1b7d49..865197e24caf8 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -23,7 +23,7 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{future, Await, ExecutionContext, Promise} +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.io.Source import scala.util.{Random, Try} @@ -362,7 +362,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { try { // Start a very-long-running query that will take hours to finish, then cancel it in order // to demonstrate that cancellation works. - val f = future { + val f = Future { statement.executeQuery( "SELECT COUNT(*) FROM test_map " + List.fill(10)("join test_map").mkString(" ")) @@ -380,7 +380,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") try { - val sf = future { + val sf = Future { statement.executeQuery( "SELECT COUNT(*) FROM test_map " + List.fill(4)("join test_map").mkString(" ") From 4f28291f851b9062da3941e63de4eabb0c77f5d0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Feb 2016 22:40:40 -0800 Subject: [PATCH 100/107] [HOTFIX] fix float part of avgRate --- core/src/main/scala/org/apache/spark/util/Benchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index d1699f5c28655..1bf6f821e9b31 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -124,7 +124,7 @@ private[spark] object Benchmark { } val best = runTimes.min val avg = runTimes.sum / iters - Result(avg / 1000000, num / (best / 1000), best / 1000000) + Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0) } } From 81da3bee669aaeb79ec68baaf7c99bff6e5d14fe Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Sat, 6 Feb 2016 17:29:09 +0000 Subject: [PATCH 101/107] [SPARK-5865][API DOC] Add doc warnings for methods that return local data structures rxin srowen I worked out the note message for the rdd.take function; please help to review. If it's fine, I can apply it to all other functions later. Author: Tommy YU Closes #10874 from Wenpei/spark-5865-add-warning-for-localdatastructure.
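The point of these warnings, in a minimal hedged sketch (it assumes a local-mode SparkContext and an arbitrary dataset size): actions such as take(n) bound how much data reaches the driver, while collect() and collectAsMap() materialise the whole dataset in driver memory.

    import org.apache.spark.{SparkConf, SparkContext}

    object DriverSideCollectSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("collect-sketch").setMaster("local[2]")
        val sc = new SparkContext(conf)
        val rdd = sc.parallelize(1 to 1000000)

        // Bounded: only ten elements are shipped back to the driver.
        val firstTen = rdd.take(10)

        // Unbounded: the entire RDD is loaded into the driver's memory,
        // which is exactly what the new @note annotations warn about.
        // val everything = rdd.collect()

        println(firstTen.mkString(", "))
        sc.stop()
      }
    }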
--- .../apache/spark/api/java/JavaPairRDD.scala | 3 +++ .../apache/spark/api/java/JavaRDDLike.scala | 24 +++++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 3 +++ .../main/scala/org/apache/spark/rdd/RDD.scala | 15 ++++++++++++ python/pyspark/rdd.py | 17 +++++++++++++ python/pyspark/sql/dataframe.py | 6 +++++ .../org/apache/spark/sql/DataFrame.scala | 4 ++++ 7 files changed, 72 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index fb04472ee73fd..94d103588b696 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -636,6 +636,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. + * + * @note this method should only be used if the resulting data is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 7340defabfe58..37c211fe706e1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -327,6 +327,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collect(): JList[T] = rdd.collect().toSeq.asJava @@ -465,6 +468,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def take(num: Int): JList[T] = rdd.take(num).toSeq.asJava @@ -548,6 +554,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -559,6 +568,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD using the * natural ordering for T and maintains the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @return an array of top elements */ @@ -570,6 +582,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. 
+ * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -601,6 +616,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @return an array of top elements */ @@ -634,6 +652,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * The asynchronous version of `collect`, which returns a future for * retrieving an array containing all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsync(): JavaFutureAction[JList[T]] = { new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) @@ -642,6 +663,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * The asynchronous version of the `take` action, which returns a * future for retrieving the first `num` elements of this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeAsync(num: Int): JavaFutureAction[JList[T]] = { new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 33f2f0b44f773..61905a8421124 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -726,6 +726,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only * one value per key is preserved in the map returned) + * + * @note this method should only be used if the resulting data is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsMap(): Map[K, V] = self.withScope { val data = self.collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e8157cf4ebe7d..a81a98b526b5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -481,6 +481,9 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator @@ -836,6 +839,9 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. 
*/ def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) @@ -1202,6 +1208,9 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @note due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ @@ -1263,6 +1272,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements @@ -1283,6 +1295,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c28594625457a..fe2264a63cf30 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -426,6 +426,9 @@ def takeSample(self, withReplacement, num, seed=None): """ Return a fixed-size sampled subset of this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) 20 @@ -766,6 +769,8 @@ def func(it): def collect(self): """ Return a list that contains all of the elements in this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) @@ -1213,6 +1218,9 @@ def top(self, num, key=None): """ Get the top N elements from a RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) @@ -1235,6 +1243,9 @@ def takeOrdered(self, num, key=None): Get the N elements from a RDD ordered in ascending order or as specified by the optional key function. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) @@ -1254,6 +1265,9 @@ def take(self, num): that partition to estimate the number of additional partitions needed to satisfy the limit. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Translated from the Scala implementation in RDD#take(). >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) @@ -1511,6 +1525,9 @@ def collectAsMap(self): """ Return the key-value pairs in this RDD to the master as a dictionary. 
+ Note that this method should only be used if the resulting data is expected + to be small, as all the data is loaded into the driver's memory. + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() >>> m[1] 2 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 90a6b5d9c0dda..3a8c8305ee3d8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -739,6 +739,9 @@ def describe(self, *cols): def head(self, n=None): """Returns the first ``n`` rows. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row. @@ -1330,6 +1333,9 @@ def toDF(self, *cols): def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + Note that this method should only be used if the resulting Pandas's DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f15b926bd27cf..7aa08fb63053b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1384,6 +1384,10 @@ class DataFrame private[sql]( /** * Returns the first `n` rows. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @group action * @since 1.3.0 */ From bc8890b357811612ba6c10d96374902b9e08134f Mon Sep 17 00:00:00 2001 From: Gary King Date: Sun, 7 Feb 2016 09:13:28 +0000 Subject: [PATCH 102/107] [SPARK-13132][MLLIB] cache standardization param value in LogisticRegression cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit. this change improves training times for one of my test sets from ~7m30s to ~4m30s Author: Gary King Closes #11027 from idigary/spark-13132-optimize-logistic-regression. 
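Both optimisations follow the same pattern; here is a minimal sketch in plain Scala (the Param class below is a simplified stand-in, not the MLlib one): read a value once outside the hot path and close over it, and pre-compute a string instead of re-interpolating it on every toString call.

    object CachedLookupSketch {
      // Simplified stand-in for ml.param.Param: the string form is computed once.
      final class Param(parent: String, name: String) {
        private val stringRepresentation = s"${parent}__$name"
        override def toString: String = stringRepresentation
      }

      def main(args: Array[String]): Unit = {
        val settings = Map("standardization" -> true)

        // Before: one map lookup per invocation of the closure.
        val perCallLookup = (index: Int) => if (settings("standardization")) 0.1 else 0.0

        // After: look the value up once and capture it, mirroring the
        // `val standardizationParam = $(standardization)` change in the patch below.
        val standardization = settings("standardization")
        val cachedLookup = (index: Int) => if (standardization) 0.1 else 0.0

        println((perCallLookup(0), cachedLookup(0)))     // (0.1,0.1)
        println(new Param("logreg", "standardization"))  // logreg__standardization
      }
    }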
--- .../apache/spark/ml/classification/LogisticRegression.scala | 3 ++- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 9b2340a1f16fc..ac0124513f283 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") ( val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept if (index == numFeatures) { 0.0 } else { - if ($(standardization)) { + if (standardizationParam) { regParamL1 } else { // If `standardization` is false, we still standardize the data diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index f48923d69974b..d7d6c0f5fa16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - override final def toString: String = s"${parent}__$name" + private[this] val stringRepresentation = s"${parent}__$name" + + override final def toString: String = stringRepresentation override final def hashCode: Int = toString.## From 140ddef373680cb08a3948a883b172dc80814170 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Sun, 7 Feb 2016 12:52:00 +0000 Subject: [PATCH 103/107] [SPARK-10963][STREAMING][KAFKA] make KafkaCluster public Author: cody koeninger Closes #9007 from koeninger/SPARK-10963. --- .../spark/streaming/kafka/KafkaCluster.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index d7885d7cc1ae1..8a66621a3125c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -29,15 +29,19 @@ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, To import kafka.consumer.{ConsumerConfig, SimpleConsumer} import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi /** + * :: DeveloperApi :: * Convenience methods for interacting with a Kafka cluster. + * See + * A Guide To The Kafka Protocol for more details on individual api calls. * @param kafkaParams Kafka * configuration parameters. * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), * NOT zookeeper servers, specified in host1:port1,host2:port2 form */ -private[spark] +@DeveloperApi class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} @@ -227,7 +231,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { // this 0 here indicates api version, in this case the original ZK backed api. 
private def defaultConsumerApiVersion: Short = 0 - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsets( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -246,7 +250,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsetMetadata( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -283,7 +287,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { Left(errs) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsets( groupId: String, offsets: Map[TopicAndPartition, Long] @@ -301,7 +305,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsetMetadata( groupId: String, metadata: Map[TopicAndPartition, OffsetAndMetadata] @@ -359,7 +363,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } -private[spark] +@DeveloperApi object KafkaCluster { type Err = ArrayBuffer[Throwable] @@ -371,7 +375,6 @@ object KafkaCluster { ) } - private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) /** @@ -379,7 +382,6 @@ object KafkaCluster { * Simple consumers connect directly to brokers, but need many of the same configs. * This subclass won't warn about missing ZK params, or presence of broker params. */ - private[spark] class SimpleConsumerConfig private(brokers: String, originalProps: Properties) extends ConsumerConfig(originalProps) { val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => @@ -391,7 +393,6 @@ object KafkaCluster { } } - private[spark] object SimpleConsumerConfig { /** * Make a consumer config without requiring group.id or zookeeper.connect, From edf4a0e62e6fdb849cca4f23a7060da5ec782b07 Mon Sep 17 00:00:00 2001 From: Nam Pham Date: Mon, 8 Feb 2016 11:06:41 -0800 Subject: [PATCH 104/107] [SPARK-12986][DOC] Fix pydoc warnings in mllib/regression.py I have fixed the warnings by running "make html" under "python/docs/". They are caused by not having blank lines around indented paragraphs. Author: Nam Pham Closes #11025 from nampham2/SPARK-12986. --- python/pyspark/mllib/regression.py | 34 ++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 13b3397501c0b..4dd7083d79c8c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -219,8 +219,10 @@ class LinearRegressionWithSGD(object): """ Train a linear regression model with no regularization using Stochastic Gradient Descent. This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/n ||A weights-y||^2 + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. 
@@ -367,8 +369,10 @@ def load(cls, sc, path): class LassoWithSGD(object): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + This solves the L1-regularized least squares regression formulation + + f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -505,8 +509,10 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + This solves the L2-regularized least squares regression formulation + + f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -655,17 +661,19 @@ class IsotonicRegression(object): Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: - Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf Sequential PAV parallelization based on: - Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. - "An approach to parallelizing isotonic regression." - Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. - Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] - @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + "An approach to parallelizing isotonic regression." + Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + + See `Isotonic regression (Wikipedia) <http://en.wikipedia.org/wiki/Isotonic_regression>`_. .. versionadded:: 1.4.0 """ From 06f0df6df204c4722ff8a6bf909abaa32a715c41 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Feb 2016 11:38:21 -0800 Subject: [PATCH 105/107] [SPARK-8964] [SQL] Use Exchange to perform shuffle in Limit This patch changes the implementation of the physical `Limit` operator so that it relies on the `Exchange` operator to perform data movement rather than directly using `ShuffledRDD`. In addition to improving efficiency, this lays the necessary groundwork for further optimization of limit, such as limit pushdown or whole-stage codegen. At a high level, this replaces the old physical `Limit` operator with two new operators, `LocalLimit` and `GlobalLimit`. `LocalLimit` performs per-partition limits, while `GlobalLimit` applies the final limit to a single partition; `GlobalLimit` declares that its `requiredInputDistribution` is `SinglePartition`, which will cause the planner to use an `Exchange` to perform the appropriate shuffles.
Thus, a logical `Limit` appearing in the middle of a query plan will be expanded into `LocalLimit -> Exchange to one partition -> GlobalLimit`. In the old code, calling `someDataFrame.limit(100).collect()` or `someDataFrame.take(100)` would actually skip the shuffle and use a fast-path which used `executeTake()` in order to avoid computing all partitions in case only a small number of rows were requested. This patch preserves this optimization by treating logical `Limit` operators specially when they appear as the terminal operator in a query plan: if a `Limit` is the final operator, then we will plan a special `CollectLimit` physical operator which implements the old `take()`-based logic. In order to be able to match on operators only at the root of the query plan, this patch introduces a special `ReturnAnswer` logical operator which functions similar to `BroadcastHint`: this dummy operator is inserted at the root of the optimized logical plan before invoking the physical planner, allowing the planner to pattern-match on it. Author: Josh Rosen Closes #7334 from JoshRosen/remove-copy-in-limit. --- .../plans/logical/basicOperators.scala | 11 ++ .../apache/spark/sql/execution/Exchange.scala | 130 ++++++++++-------- .../spark/sql/execution/QueryExecution.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../spark/sql/execution/basicOperators.scala | 95 +------------ .../apache/spark/sql/execution/limit.scala | 122 ++++++++++++++++ .../spark/sql/execution/PlannerSuite.scala | 10 +- .../spark/sql/execution/SortSuite.scala | 4 +- 8 files changed, 223 insertions(+), 160 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 03a79520cbd3a..57575f9ee09ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -25,6 +25,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +/** + * When planning take() or collect() operations, this special node that is inserted at the top of + * the logical plan before invoking the query planner. + * + * Rules can pattern-match on this node in order to apply transformations that only take effect + * at the top of the logical query plan. 
+ */ +case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3770883af1e2f..97f65f18bfdcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -57,6 +57,69 @@ case class Exchange( override def output: Seq[Attribute] = child.output + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). + // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer) + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. 
+ specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = UnknownPartitioning(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } +} + +object Exchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) + } + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -82,7 +145,7 @@ case class Exchange( // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output // fewer partitions (like RangePartitioner, for example). - val conf = child.sqlContext.sparkContext.conf + val conf = SparkEnv.get.conf val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) @@ -117,30 +180,16 @@ case class Exchange( } } - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - - override protected def doPrepare(): Unit = { - // If an ExchangeCoordinator is needed, we register this Exchange operator - // to the coordinator when we do prepare. It is important to make sure - // we register this operator right before the execution instead of register it - // in the constructor because it is possible that we create new instances of - // Exchange operators when we transform the physical plan - // (then the ExchangeCoordinator will hold references of unneeded Exchanges). - // So, we should only call registerExchange just before we start to execute - // the plan. - coordinator match { - case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) - case None => - } - } - /** * Returns a [[ShuffleDependency]] that will partition rows of its child based on * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. */ - private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { - val rdd = child.execute() + private[sql] def prepareShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) case HashPartitioning(_, n) => @@ -160,7 +209,7 @@ case class Exchange( // We need to use an interpreted ordering here because generated orderings cannot be // serialized and this ordering needs to be created on the driver in order to be passed into // Spark core code. 
- implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) + implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { @@ -180,7 +229,7 @@ case class Exchange( position } case h: HashPartitioning => - val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output) + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -211,43 +260,6 @@ case class Exchange( dependency } - - /** - * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. - * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional - * partition start indices array. If this optional array is defined, the returned - * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. - */ - private[sql] def preparePostShuffleRDD( - shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], - specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { - // If an array of partition start indices is provided, we need to use this array - // to create the ShuffledRowRDD. Also, we need to update newPartitioning to - // update the number of post-shuffle partitions. - specifiedPartitionStartIndices.foreach { indices => - assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = UnknownPartitioning(indices.length) - } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case None => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) - } - } -} - -object Exchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 107570f9dbcc8..8616fe317034f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.planner.plan(optimizedPlan).next() + sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. 
It should be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 830bb011beab4..ee392e4e8debe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -338,8 +338,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil + case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) => + execution.CollectLimit(limit, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child)) :: Nil + val perPartitionLimit = execution.LocalLimit(limit, planLater(child)) + val globalLimit = execution.GlobalLimit(limit, perPartitionLimit) + globalLimit :: Nil case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => @@ -358,6 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil + case logical.ReturnAnswer(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6e51c4d84824a..f63e8a9b6d79d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.execution -import org.apache.spark.{HashPartitioner, SparkEnv} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} -import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType -import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler case class Project(projectList: Seq[NamedExpression], child: SparkPlan) @@ -306,96 +303,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { sparkContext.union(children.map(_.execute())) } -/** - * Take the first limit elements. Note that the implementation is different depending on whether - * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, - * this operator uses something similar to Spark's take method on the Spark driver. If it is not - * terminal or is invoked using execute, we first take the limit on each partition, and then - * repartition all the data to a single partition to compute the global limit. 
- */ -case class Limit(limit: Int, child: SparkPlan) - extends UnaryNode { - // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: - // partition local limit -> exchange into one partition -> partition local limit again - - /** We must copy rows when sort based shuffle is on */ - private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition - - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) - - protected override def doExecute(): RDD[InternalRow] = { - val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitionsInternal { iter => - iter.take(limit).map(row => (false, row.copy())) - } - } else { - child.execute().mapPartitionsInternal { iter => - val mutablePair = new MutablePair[Boolean, InternalRow]() - iter.take(limit).map(row => mutablePair.update(false, row)) - } - } - val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitionsInternal(_.take(limit).map(_._2)) - } -} - -/** - * Take the first limit elements as defined by the sortOrder, and do projection if needed. - * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, - * or having a [[Project]] operator between them. - * This could have been named TopK, but Spark's top operator does the opposite in ordering - * so we name it TakeOrdered to avoid confusion. - */ -case class TakeOrderedAndProject( - limit: Int, - sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], - child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } - - override def outputPartitioning: Partitioning = SinglePartition - - // We need to use an interpreted ordering here because generated orderings cannot be serialized - // and this ordering needs to be created on the driver in order to be passed into Spark core code. - private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - - private def collectData(): Array[InternalRow] = { - val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) - data.map(r => proj(r).copy()) - } else { - data - } - } - - override def executeCollect(): Array[InternalRow] = { - collectData() - } - - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def simpleString: String = { - val orderByString = sortOrder.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") - - s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" - } -} - /** * Return a new RDD that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. 
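To see why the planner change above is safe, here is a minimal, Spark-free Scala sketch of the two-phase limit that replaces the removed all-in-one Limit operator; TwoPhaseLimitSketch, localLimit and globalLimit are illustrative names, and plain Seq values stand in for partitions — this is a sketch of the idea, not the operators themselves:

object TwoPhaseLimitSketch {
  // "LocalLimit": keep at most `limit` rows inside every partition, no shuffle needed.
  def localLimit[T](partitions: Seq[Seq[T]], limit: Int): Seq[Seq[T]] =
    partitions.map(_.take(limit))

  // "GlobalLimit": merge the surviving rows into a single partition and take `limit` again.
  def globalLimit[T](partitions: Seq[Seq[T]], limit: Int): Seq[T] =
    partitions.flatten.take(limit)

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6, 7, 8, 9))
    // Same answer as the old single Limit operator, but only limit-many rows per
    // partition ever reach the final single-partition stage.
    println(globalLimit(localLimit(partitions, 2), 2)) // List(1, 2)
  }
}

This is exactly the plan the removed operator's TODO asked for (partition-local limit, exchange into one partition, limit again), so far fewer rows are shuffled. For a terminal limit the planner goes further: ReturnAnswer(Limit(n, child)) is planned as the new CollectLimit operator below, whose executeCollect simply calls child.executeTake(limit) on the driver and avoids the shuffle entirely.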
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala new file mode 100644 index 0000000000000..256f4228ae99e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ + + +/** + * Take the first `limit` elements and collect them to a single partition. + * + * This operator will be used when a logical `Limit` operation is the final operator in an + * logical plan, which happens when the user is collecting results back to the driver. + */ +case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + protected override def doExecute(): RDD[InternalRow] = { + val shuffled = new ShuffledRowRDD( + Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer)) + shuffled.mapPartitionsInternal(_.take(limit)) + } +} + +/** + * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. + */ +trait BaseLimit extends UnaryNode { + val limit: Int + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + iter.take(limit) + } +} + +/** + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + */ +case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + +/** + * Take the first `limit` elements of the child's single output partition. + */ +case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil +} + +/** + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a Limit operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. 
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. + */ +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def outputPartitioning: Partitioning = SinglePartition + + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) + + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } + } + + override def executeCollect(): Array[InternalRow] = { + collectData() + } + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. + protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index adaeb513bc1b8..a64ad4038c7c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -181,6 +181,12 @@ class PlannerSuite extends SharedSQLContext { } } + test("terminal limits use CollectLimit") { + val query = testData.select('value).limit(2) + val planned = query.queryExecution.sparkPlan + assert(planned.isInstanceOf[CollectLimit]) + } + test("PartitioningCollection") { withTempTable("normal", "small", "tiny") { testData.registerTempTable("normal") @@ -200,7 +206,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: Exchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -215,7 +221,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: Exchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 6259453da26a1..cb6d68dc3ac46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -56,8 +56,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)), - (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, 
global = true, child)), + (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)), sortAnswers = false ) } From 8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Feb 2016 12:06:00 -0800 Subject: [PATCH 106/107] [SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder nullability should only be considered as an optimization rather than part of the type system, so instead of failing analysis for mismatch nullability, we should pass analysis and add runtime null check. Author: Wenchen Fan Closes #11035 from cloud-fan/ignore-nullability. --- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 29 +++-- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/expressions/objects.scala | 20 ++-- .../encoders/EncoderResolutionSuite.scala | 107 +++++------------- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 4 +- 7 files changed, 64 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 3c3717d5043aa..59ee41d02f198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -292,7 +292,7 @@ object JavaTypeInference { val setter = if (nullable) { constructor } else { - AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + AssertNotNull(constructor, Seq("currently no type path record in java")) } p.getWriteMethod.getName -> setter }.toMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e5811efb436a6..02cb2d9a2b118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + + // TODO: add runtime null check for primitive array val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) + + val mapFunction: Expression => Expression = p => { + val converter = constructorFor(elementType, Some(p), newTypePath) + if (nullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } + } + + val array = Invoke( + MapObjects(mapFunction, getPath, dataType), + "array", + ObjectType(classOf[Array[Any]])) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, 
ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + array :: Nil) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map @@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection { newTypePath) if (!nullable) { - AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + AssertNotNull(constructor, newTypePath) } else { constructor } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb228cf52b433..4d53b232d5510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] { fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case _ => Cast(child, dataType) + case _ => Cast(child, dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 79fe0033b71ab..fef6825b2db5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -365,7 +365,7 @@ object MapObjects { * to handle collection elements. * @param inputData An expression that when evaluted returns a collection object. */ -case class MapObjects( +case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression) extends Expression { @@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. */ -case class AssertNotNull( - child: Expression, parentType: String, fieldName: String, fieldType: String) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) extends UnaryExpression { override def dataType: DataType = child.dataType @@ -651,6 +650,14 @@ case class AssertNotNull( override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childGen = child.gen(ctx) + val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + val idx = ctx.references.length + ctx.references += errMsg + ev.isNull = "false" ev.value = childGen.value @@ -658,12 +665,7 @@ case class AssertNotNull( ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException( - "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." 
- ); + throw new RuntimeException((String) references[$idx]); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 92a68a4dba915..8b02b63c6cf3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class StringLongClass(a: String, b: Long) @@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) class EncoderResolutionSuite extends PlanTest { + private val str = UTF8String.fromString("hello") + test("real type doesn't match encoder schema but they are compatible: product") { val encoder = ExpressionEncoder[StringLongClass] - val cls = classOf[StringLongClass] - - { - val attrs = Seq('a.string, 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - toExternalString('a.string), - AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to long type + val attrs1 = Seq('a.string, 'b.int) + encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) - { - val attrs = Seq('a.int, 'b.long) - val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression - val expected = NewInstance( - cls, - Seq( - toExternalString('a.int.cast(StringType)), - AssertNotNull('b.long, cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to string type + val attrs2 = Seq('a.int, 'b.long) + encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] - val innerCls = classOf[StringLongClass] - val cls = classOf[ComplexClass] - val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), - If( - 'b.struct('a.int, 'b.long).isNull, - Literal.create(null, ObjectType(innerCls)), - NewInstance( - innerCls, - Seq( - toExternalString( - GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - AssertNotNull( - GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), - innerCls.getName, "b", "Long")), - ObjectType(innerCls), - propagateNull = false) - )), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled 
encoder") { val encoder = ExpressionEncoder.tuple( ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) - val cls = classOf[StringLongClass] - val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - classOf[Tuple2[_, _]], - Seq( - NewInstance( - cls, - Seq( - toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - AssertNotNull( - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), - cls.getName, "b", "Long")), - ObjectType(cls), - propagateNull = false), - 'b.int.cast(LongType)), - ObjectType(classOf[Tuple2[_, _]]), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + } + + test("nullability of array type element should not fail analysis") { + val encoder = ExpressionEncoder[Seq[Int]] + val attrs = 'a.array(IntegerType) :: Nil + + // It should pass analysis + val bound = encoder.resolve(attrs, null).bind(attrs) + + // If no null values appear, it should works fine + bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) + + // If there is null value, it should throw runtime exception + val e = intercept[RuntimeException] { + bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + } + assert(e.getMessage.contains("Null value appeared in non-nullable field")) } test("the real number of fields doesn't match encoder schema: tuple encoder") { @@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest { } } - private def toExternalString(e: Expression): Expression = { - Invoke(e, "toString", ObjectType(classOf[String]), Nil) - } - test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a6fb62c17d59b..1181244c8a4ed 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() { } nullabilityCheck.expect(RuntimeException.class); - nullabilityCheck.expectMessage( - "Null value appeared in non-nullable field " + - "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + nullabilityCheck.expectMessage("Null value appeared in non-nullable field"); { Row row = new GenericRow(new Object[] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 374f4320a9239..f9ba60770022d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { buildDataset(Row(Row("hello", null))).collect() }.getMessage - assert(message.contains( - "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." 
- )) + assert(message.contains("Null value appeared in non-nullable field")) } test("SPARK-12478: top level null field") { From 37bc203c8dd5022cb11d53b697c28a737ee85bcc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Feb 2016 12:08:58 -0800 Subject: [PATCH 107/107] [SPARK-13210][SQL] catch OOM when allocate memory and expand array There is a bug when we try to grow the pointer array: an OutOfMemoryError is wrongly ignored (the assert is also skipped by the JVM), so we try to grow the array again, and this second attempt can trigger spilling, which frees the current page and invalidates the record we just inserted. The root cause is that the JVM has less free memory than the MemoryManager thinks, so allocating a page can hit OOM without triggering spilling. We should catch the OOM and acquire the memory again so that spilling is triggered. Also, we should not grow the array inside `insertRecord` of `InMemorySorter` (it was only there to make testing easy). Author: Davies Liu Closes #11095 from davies/fix_expand. --- .../spark/memory/TaskMemoryManager.java | 23 ++++++++++++++++++- .../shuffle/sort/ShuffleExternalSorter.java | 10 +------- .../shuffle/sort/ShuffleInMemorySorter.java | 2 +- .../unsafe/sort/UnsafeExternalSorter.java | 10 +------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../sort/ShuffleInMemorySorterSuite.java | 6 +++++ .../sort/UnsafeInMemorySorterSuite.java | 3 +++ 7 files changed, 35 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d2a88864f7ac9..8757dff36f159 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -111,6 +111,11 @@ public class TaskMemoryManager { @GuardedBy("this") private final HashSet consumers; + /** + * The amount of memory that is acquired but not used. + */ + private long acquiredButNotUsed = 0L; + /** * Construct a new TaskMemoryManager. */ @@ -256,7 +261,20 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { } allocatedPages.set(pageNumber); } - final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + MemoryBlock page = null; + try { + page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); + // there is not enough memory actually, it means the actual free memory is smaller than + // the MemoryManager thought, so we should keep the acquired memory. + acquiredButNotUsed += acquired; + synchronized (this) { + allocatedPages.clear(pageNumber); + } + // this could trigger spilling to free some pages. + return allocatePage(size, consumer); + } page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { @@ -378,6 +396,9 @@ public long cleanUpAllAllocatedMemory() { } Arrays.fill(pageTable, null); + // release the memory that is not used by any consumer.
+ memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 2c84de5bf2a5a..f97e76d7ed0d9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -320,15 +320,7 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - LongArray array; - try { - // could trigger spilling - array = allocateArray(used / 8 * 2); - } catch (OutOfMemoryError e) { - // should have trigger spilling - assert(inMemSorter.hasSpaceForAnotherRecord()); - return; - } + LongArray array = allocateArray(used / 8 * 2); // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { freeArray(array); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 58ad88e1ed87b..d74602cd205ad 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -104,7 +104,7 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(consumer.allocateArray(array.size() * 2)); + throw new IllegalStateException("There is no space for new record"); } array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index a6edc1ad3f665..296bf722fc178 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -293,15 +293,7 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - LongArray array; - try { - // could trigger spilling - array = allocateArray(used / 8 * 2); - } catch (OutOfMemoryError e) { - // should have trigger spilling - assert(inMemSorter.hasSpaceForAnotherRecord()); - return; - } + LongArray array = allocateArray(used / 8 * 2); // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { freeArray(array); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index d1b0bc5d11f46..cea0f0a0c6c11 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -164,7 +164,7 @@ public void expandPointerArray(LongArray newArray) { */ public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(consumer.allocateArray(array.size() * 2)); + 
throw new IllegalStateException("There is no space for new record"); } array.set(pos, recordPointer); pos++; diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 0328e63e45439..eb1da8e1b43eb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -75,6 +75,9 @@ public void testBasicSorting() throws Exception { // Write the records into the data page and store pointers into the sorter long position = dataPage.getBaseOffset(); for (String str : dataToSort) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); Platform.putInt(baseObject, position, strBytes.length); @@ -114,6 +117,9 @@ public void testSortingManyNumbers() throws Exception { int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); sorter.insertRecord(0, numbersToSort[i]); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 93efd033eb940..8e557ec0ab0b4 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -111,6 +111,9 @@ public int compare(long prefix1, long prefix2) { // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2)); + } // position now points to the start of a record (which holds its length). final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
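To summarize the allocation change at the heart of this patch, here is a minimal Scala sketch of the catch-OOM-and-retry pattern; it is an illustration only (Page, grant, rawAllocate and recordUnused are hypothetical stand-ins), not the Java TaskMemoryManager code above:

object AllocateWithRetrySketch {
  final case class Page(bytes: Long)

  // `grant` models the memory-manager bookkeeping, `rawAllocate` the actual JVM
  // allocation, and `recordUnused` the acquired-but-not-used accounting; all three
  // are placeholders for this sketch.
  def allocatePage(size: Long,
                   grant: Long => Long,
                   rawAllocate: Long => Page,
                   recordUnused: Long => Unit): Page = {
    val acquired = grant(size)
    try {
      rawAllocate(acquired)
    } catch {
      case _: OutOfMemoryError =>
        // The JVM had less free memory than the bookkeeping assumed: keep the grant
        // on the books so the next grant request forces spilling, then retry. The
        // sketch assumes the retry eventually succeeds once spilling has freed pages.
        recordUnused(acquired)
        allocatePage(size, grant, rawAllocate, recordUnused)
    }
  }
}

Keeping the granted-but-unused bytes recorded is also why cleanUpAllAllocatedMemory above now releases acquiredButNotUsed back to the memory manager before releasing the task's remaining execution memory.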