diff --git a/core/src/main/scala/org/apache/spark/util/CastedArray.scala b/core/src/main/scala/org/apache/spark/util/CastedArray.scala
new file mode 100644
index 0000000000000..82435fb784436
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CastedArray.scala
@@ -0,0 +1,126 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.util
+
+/**
+ * Wraps an object that is known to be an array but whose element type is not known
+ * statically, exposing uniform `get` and `getLength` accessors over it.
+ *
+ * The usual way to access such an array is reflectively, through the `get()` and
+ * `getLength()` methods of [[java.lang.reflect.Array]]. Those methods perform poorly.
+ *
+ * Type-testing the array with `isInstanceOf` and casting it is faster than calling through
+ * to the native C implementation; there is some discussion and a sample code snippet in
+ * an open JDK ticket. This class implements that approach with a twist: the array is cast
+ * once, when it is wrapped, rather than on every access. `get()` is invoked many times per
+ * array, so for larger arrays a single cast is worth the cost of constructing the wrapper.
+ *
+ * Instances are obtained from `CastedArray.castAndWrap`.
+ */
+sealed trait CastedArray extends Any {
+ /** Returns the element at index `i`; primitive elements are returned boxed. */
+ def get(i: Int): AnyRef
+ /** Returns the number of elements in the wrapped array. */
+ def getLength(): Int
+}
+
+object CastedArray {
+  /**
+   * Wraps `arr` in the [[CastedArray]] implementation matching its runtime array type, so
+   * that the type test and cast happen once here rather than on every element access.
+   *
+   * @param arr the array to wrap; may be an array of primitives or of references
+   * @return a wrapper exposing the array's length and its elements (boxed if primitive)
+   * @throws NullPointerException if `arr` is null
+   * @throws IllegalArgumentException if `arr` is not an array
+   */
+  def castAndWrap(arr: AnyRef): CastedArray = arr match {
+    case booleans: Array[Boolean] => new BooleanCastedArray(booleans)
+    case bytes: Array[Byte] => new ByteCastedArray(bytes)
+    case chars: Array[Char] => new CharCastedArray(chars)
+    case doubles: Array[Double] => new DoubleCastedArray(doubles)
+    case floats: Array[Float] => new FloatCastedArray(floats)
+    case ints: Array[Int] => new IntCastedArray(ints)
+    case longs: Array[Long] => new LongCastedArray(longs)
+    case objects: Array[Object] => new ObjectCastedArray(objects)
+    case shorts: Array[Short] => new ShortCastedArray(shorts)
+    // A null argument matches none of the typed patterns above, so it also lands here.
+    case _ => throw createBadArrayException(arr)
+  }
+
+  // The wrappers below box primitive elements explicitly because `get` must return AnyRef.
+  // Boxing is not ideal, but a Java implementation returning Object would box implicitly
+  // anyway, and in practice the cost tends to be acceptable. All wrappers are private;
+  // callers obtain them only through castAndWrap.
+  private class BooleanCastedArray(val arr: Array[Boolean]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Boolean.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class ByteCastedArray(val arr: Array[Byte]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Byte.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class CharCastedArray(val arr: Array[Char]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Char.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class DoubleCastedArray(val arr: Array[Double]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Double.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class FloatCastedArray(val arr: Array[Float]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Float.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class IntCastedArray(val arr: Array[Int]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Int.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class LongCastedArray(val arr: Array[Long]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Long.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private class ObjectCastedArray(val arr: Array[Object]) extends AnyVal with CastedArray {
+    override def get(i: Int): Object = arr(i)
+    override def getLength(): Int = arr.length
+  }
+
+  private class ShortCastedArray(val arr: Array[Short]) extends AnyVal with CastedArray {
+    override def get(i: Int): AnyRef = Short.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  /** Builds, without throwing, the exception explaining why `badArray` cannot be wrapped. */
+  private def createBadArrayException(badArray: Object): RuntimeException = {
+    if (badArray == null) {
+      new NullPointerException("Array argument is null")
+    } else if (!badArray.getClass.isArray) {
+      new IllegalArgumentException("Argument is not an array")
+    } else {
+      new IllegalArgumentException("Array is of incompatible type")
+    }
+  }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java b/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java
deleted file mode 100644
index b3a04b6cb03c7..0000000000000
--- a/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java
+++ /dev/null
@@ -1,247 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.util;
-
-/**
- * Provides a wrapper around an object that is known to be an array, but the specific
- * type for the array is unknown.
- *
- * Normally, in situations when such an array is to be accessed reflectively, one would use
- * {@link java.lang.reflect.Array} using getLength() and get() methods. However, it turns
- * out that such methods are ill-performant.
- *
- * It turns out it is better to just use instanceOf and lots of casting over calling through
- * to the native C implementation. There is some discussion and a sample code snippet in
- * an open JDK ticket. In this
- * class, that approach is implemented in an alternative way: creating a wrapper object to
- * wrap the array allows the cast to be done once, so the overhead of casting multiple times
- * is also avoided. It turns out we invoke the get() method to get the value of the array
- * numerous times, so doing the cast just once is worth the cost of constructing the wrapper
- * object for larger arrays.
- */
-public abstract class ReflectArrayGetter {
-
- public static ReflectArrayGetter create(Object array) {
- if (array instanceof Object[]) {
- Object[] casted = (Object[]) array;
- return new ObjectArrayGetter(casted);
- } else if (array instanceof boolean[]) {
- boolean[] casted = (boolean[]) array;
- return new BooleanArrayGetter(casted);
- } else if (array instanceof int[]) {
- int[] casted = (int[]) array;
- return new IntArrayGetter(casted);
- } else if (array instanceof byte[]) {
- byte[] casted = (byte[]) array;
- return new ByteArrayGetter(casted);
- } else if (array instanceof short[]) {
- short[] casted = (short[]) array;
- return new ShortArrayGetter(casted);
- } else if (array instanceof char[]) {
- char[] casted = (char[]) array;
- return new CharArrayGetter(casted);
- } else if (array instanceof long[]) {
- long[] casted = (long[]) array;
- return new LongArrayGetter(casted);
- } else if (array instanceof double[]) {
- double[] casted = (double[]) array;
- return new DoubleArrayGetter(casted);
- } else if (array instanceof float[]) {
- float[] casted = (float[]) array;
- return new FloatArrayGetter(casted);
- } else {
- throw badArray(array);
- }
- }
-
- public abstract int getLength();
-
- public abstract Object get(Integer i);
-
- private static RuntimeException badArray(Object array) {
- if (array == null) {
- return new NullPointerException("Array argument is null");
- } else if (!array.getClass().isArray()) {
- return new IllegalArgumentException("Argument is not an array");
- } else {
- return new IllegalArgumentException("Array is of incompatible type");
- }
- }
-
- private static class BooleanArrayGetter extends ReflectArrayGetter {
- private boolean[] arr;
-
- public BooleanArrayGetter(boolean[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Boolean get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class ByteArrayGetter extends ReflectArrayGetter {
- private byte[] arr;
-
- public ByteArrayGetter(byte[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Byte get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class CharArrayGetter extends ReflectArrayGetter {
- private char[] arr;
-
- public CharArrayGetter(char[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Character get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class IntArrayGetter extends ReflectArrayGetter {
- private int[] arr;
-
- public IntArrayGetter(int[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Integer get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class ShortArrayGetter extends ReflectArrayGetter {
- private short[] arr;
-
- public ShortArrayGetter(short[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Short get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class LongArrayGetter extends ReflectArrayGetter {
- private long[] arr;
-
- public LongArrayGetter(long[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Long get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class FloatArrayGetter extends ReflectArrayGetter {
- private float[] arr;
-
- public FloatArrayGetter(float[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Float get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class DoubleArrayGetter extends ReflectArrayGetter {
- private double[] arr;
-
- public DoubleArrayGetter(double[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Double get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-
- private static class ObjectArrayGetter extends ReflectArrayGetter {
- private Object[] arr;
-
- public ObjectArrayGetter(Object[] arr) {
- this.arr = arr;
- }
-
- @Override
- public Object get(Integer v1) {
- return arr[v1];
- }
-
- @Override
- public int getLength() {
- return arr.length;
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 57619d7eabef6..e7151a8fd3b0e 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -184,8 +184,8 @@ private[spark] object SizeEstimator extends Logging {
private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
- val reflectArrayGetter = ReflectArrayGetter.create(array)
- val length = reflectArrayGetter.getLength
+ val castedArray: CastedArray = CastedArray.castAndWrap(array)
+ val length = castedArray.getLength
val elementClass = cls.getComponentType
// Arrays have object header and length field which is an integer
@@ -200,7 +200,7 @@ private[spark] object SizeEstimator extends Logging {
if (length <= ARRAY_SIZE_FOR_SAMPLING) {
for (i <- 0 until length) {
- state.enqueue(reflectArrayGetter.get(i))
+ state.enqueue(castedArray.get(i))
}
} else {
// Estimate the size of a large array by sampling elements without replacement.
@@ -213,7 +213,7 @@ private[spark] object SizeEstimator extends Logging {
index = rand.nextInt(length)
} while (drawn.contains(index))
drawn.add(index)
- val elem = reflectArrayGetter.get(index)
+ val elem = castedArray.get(index)
size += SizeEstimator.estimate(elem, state.visited)
}
state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong