diff --git a/core/src/main/scala/org/apache/spark/util/CastedArray.scala b/core/src/main/scala/org/apache/spark/util/CastedArray.scala
new file mode 100644
index 0000000000000..82435fb784436
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CastedArray.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Wraps an object that is known to be an array but whose element type is not known
+ * statically, giving indexed access and length without reflection.
+ *
+ * The usual way to read such an array is through the `getLength()` and `get()` methods of
+ * [[java.lang.reflect.Array]]. However, it turns out that those methods are slow.
+ *
+ * It is better to use `isInstanceOf` and casting than to call through to the native C
+ * implementation. There is some discussion and a sample code snippet in an open JDK ticket.
+ * This class implements that approach in an alternative way: creating a wrapper object
+ * around the array allows the cast to be done once, so the overhead of casting on every
+ * access is also avoided. `get()` is invoked numerous times per array, so doing the cast
+ * just once is worth the cost of constructing the wrapper object for larger arrays.
+ */
+sealed trait CastedArray {
+  /** Returns the element at index `i`; elements of primitive arrays are returned boxed. */
+  def get(i: Int): AnyRef
+
+  /** Returns the number of elements in the wrapped array. */
+  def getLength(): Int
+}
+
+object CastedArray {
+  /**
+   * Casts `arr` to its concrete array type once and wraps it in a [[CastedArray]].
+   *
+   * @param arr the array to wrap: a primitive array or an array of references
+   * @return a wrapper that reads elements of `arr` without further casts
+   * @throws NullPointerException if `arr` is null
+   * @throws IllegalArgumentException if `arr` is not an array
+   */
+  def castAndWrap(arr: AnyRef): CastedArray = arr match {
+    // The cases are mutually exclusive, so their order does not affect the result. Reference
+    // arrays are tested first, as in the ReflectArrayGetter that this class replaces.
+    case objects: Array[Object] => new ObjectCastedArray(objects)
+    case booleans: Array[Boolean] => new BooleanCastedArray(booleans)
+    case ints: Array[Int] => new IntCastedArray(ints)
+    case bytes: Array[Byte] => new ByteCastedArray(bytes)
+    case shorts: Array[Short] => new ShortCastedArray(shorts)
+    case chars: Array[Char] => new CharCastedArray(chars)
+    case longs: Array[Long] => new LongCastedArray(longs)
+    case doubles: Array[Double] => new DoubleCastedArray(doubles)
+    case floats: Array[Float] => new FloatCastedArray(floats)
+    // Reached for anything that is not an array, including null.
+    case _ => throw createBadArrayException(arr)
+  }
+
+  // Boxing is not ideal, but get() has to return AnyRef. An alternative implementation in
+  // Java would not force explicit boxing, but returning Object there would make the boxing
+  // happen implicitly anyway. In practice this tends to be okay in terms of performance.
+  //
+  // The wrappers are plain classes rather than value classes: they are only ever handed out
+  // as CastedArray, and a value class is instantiated whenever it is treated as another type,
+  // so `extends AnyVal` would not save the wrapper allocation.
+
+  private final class BooleanCastedArray(arr: Array[Boolean]) extends CastedArray {
+    override def get(i: Int): AnyRef = Boolean.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class ByteCastedArray(arr: Array[Byte]) extends CastedArray {
+    override def get(i: Int): AnyRef = Byte.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class CharCastedArray(arr: Array[Char]) extends CastedArray {
+    override def get(i: Int): AnyRef = Char.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class DoubleCastedArray(arr: Array[Double]) extends CastedArray {
+    override def get(i: Int): AnyRef = Double.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class FloatCastedArray(arr: Array[Float]) extends CastedArray {
+    override def get(i: Int): AnyRef = Float.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class IntCastedArray(arr: Array[Int]) extends CastedArray {
+    override def get(i: Int): AnyRef = Int.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class LongCastedArray(arr: Array[Long]) extends CastedArray {
+    override def get(i: Int): AnyRef = Long.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  private final class ObjectCastedArray(arr: Array[Object]) extends CastedArray {
+    override def get(i: Int): AnyRef = arr(i)
+    override def getLength(): Int = arr.length
+  }
+
+  private final class ShortCastedArray(arr: Array[Short]) extends CastedArray {
+    override def get(i: Int): AnyRef = Short.box(arr(i))
+    override def getLength(): Int = arr.length
+  }
+
+  /**
+   * Builds the exception describing why `badArray` cannot be wrapped.
+   *
+   * The caller is expected to throw the result.
+   */
+  private def createBadArrayException(badArray: Object): RuntimeException = {
+    if (badArray == null) {
+      new NullPointerException("Array argument is null")
+    } else if (!badArray.getClass.isArray) {
+      new IllegalArgumentException("Argument is not an array")
+    } else {
+      new IllegalArgumentException("Array is of incompatible type")
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java b/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java
deleted file mode 100644
index b3a04b6cb03c7..0000000000000
--- a/core/src/main/scala/org/apache/spark/util/ReflectArrayGetter.java
+++ /dev/null
@@ -1,247 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements.  See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License.  You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.util;
-
-/**
- * Provides a wrapper around an object that is known to be an array, but the specific
- * type for the array is unknown.
- *
- * Normally, in situations when such an array is to be accessed reflectively, one would use
- * {@link java.lang.reflect.Array} using getLength() and get() methods. However, it turns
- * out that such methods are ill-performant.
- *
- * It turns out it is better to just use instanceOf and lots of casting over calling through
- * to the native C implementation. There is some discussion and a sample code snippet in
- * an open JDK ticket. In this
- * class, that approach is implemented in an alternative way: creating a wrapper object to
- * wrap the array allows the cast to be done once, so the overhead of casting multiple times
- * is also avoided. It turns out we invoke the get() method to get the value of the array
- * numerous times, so doing the cast just once is worth the cost of constructing the wrapper
- * object for larger arrays.
- */
-public abstract class ReflectArrayGetter {
-
-  public static ReflectArrayGetter create(Object array) {
-    if (array instanceof Object[]) {
-      Object[] casted = (Object[]) array;
-      return new ObjectArrayGetter(casted);
-    } else if (array instanceof boolean[]) {
-      boolean[] casted = (boolean[]) array;
-      return new BooleanArrayGetter(casted);
-    } else if (array instanceof int[]) {
-      int[] casted = (int[]) array;
-      return new IntArrayGetter(casted);
-    } else if (array instanceof byte[]) {
-      byte[] casted = (byte[]) array;
-      return new ByteArrayGetter(casted);
-    } else if (array instanceof short[]) {
-      short[] casted = (short[]) array;
-      return new ShortArrayGetter(casted);
-    } else if (array instanceof char[]) {
-      char[] casted = (char[]) array;
-      return new CharArrayGetter(casted);
-    } else if (array instanceof long[]) {
-      long[] casted = (long[]) array;
-      return new LongArrayGetter(casted);
-    } else if (array instanceof double[]) {
-      double[] casted = (double[]) array;
-      return new DoubleArrayGetter(casted);
-    } else if (array instanceof float[]) {
-      float[] casted = (float[]) array;
-      return new FloatArrayGetter(casted);
-    } else {
-      throw badArray(array);
-    }
-  }
-
-  public abstract int getLength();
-
-  public abstract Object get(Integer i);
-
-  private static RuntimeException badArray(Object array) {
-    if (array == null) {
-      return new NullPointerException("Array argument is null");
-    } else if (!array.getClass().isArray()) {
-      return new IllegalArgumentException("Argument is not an array");
-    } else {
-      return new IllegalArgumentException("Array is of incompatible type");
-    }
-  }
-
-  private static class BooleanArrayGetter extends ReflectArrayGetter {
-    private boolean[] arr;
-
-    public BooleanArrayGetter(boolean[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Boolean get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class ByteArrayGetter extends ReflectArrayGetter {
-    private byte[] arr;
-
-    public ByteArrayGetter(byte[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Byte get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class CharArrayGetter extends ReflectArrayGetter {
-    private char[] arr;
-
-    public CharArrayGetter(char[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Character get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class IntArrayGetter extends ReflectArrayGetter {
-    private int[] arr;
-
-    public IntArrayGetter(int[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Integer get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class ShortArrayGetter extends ReflectArrayGetter {
-    private short[] arr;
-
-    public ShortArrayGetter(short[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Short get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class LongArrayGetter extends ReflectArrayGetter {
-    private long[] arr;
-
-    public LongArrayGetter(long[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Long get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class FloatArrayGetter extends ReflectArrayGetter {
-    private float[] arr;
-
-    public FloatArrayGetter(float[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Float get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class DoubleArrayGetter extends ReflectArrayGetter {
-    private double[] arr;
-
-    public DoubleArrayGetter(double[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Double get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-
-  private static class ObjectArrayGetter extends ReflectArrayGetter {
-    private Object[] arr;
-
-    public ObjectArrayGetter(Object[] arr) {
-      this.arr = arr;
-    }
-
-    @Override
-    public Object get(Integer v1) {
-      return arr[v1];
-    }
-
-    @Override
-    public int getLength() {
-      return arr.length;
-    }
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 57619d7eabef6..e7151a8fd3b0e 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -184,8 +184,8 @@ private[spark] object SizeEstimator extends Logging {
   private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
 
   private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
-    val reflectArrayGetter = ReflectArrayGetter.create(array)
-    val length = reflectArrayGetter.getLength
+    val castedArray: CastedArray = CastedArray.castAndWrap(array)
+    val length = castedArray.getLength
     val elementClass = cls.getComponentType
 
     // Arrays have object header and length field which is an integer
@@ -200,7 +200,7 @@ private[spark] object SizeEstimator extends Logging {
 
     if (length <= ARRAY_SIZE_FOR_SAMPLING) {
       for (i <- 0 until length) {
-        state.enqueue(reflectArrayGetter.get(i))
+        state.enqueue(castedArray.get(i))
       }
     } else {
       // Estimate the size of a large array by sampling elements without replacement.
@@ -213,7 +213,7 @@ private[spark] object SizeEstimator extends Logging {
           index = rand.nextInt(length)
         } while (drawn.contains(index))
         drawn.add(index)
-        val elem = reflectArrayGetter.get(index)
+        val elem = castedArray.get(index)
         size += SizeEstimator.estimate(elem, state.visited)
       }
       state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong