From 64677599cb57f84e7b1d9d62eb6f97e7373e18a2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 12 Mar 2015 10:34:32 -0700 Subject: [PATCH] Including primitive size information inside CastedArray. --- .../org/apache/spark/util/CastedArray.scala | 61 ++++++++++++------- .../apache/spark/util/PrimitiveSizes.scala | 32 ++++++++++ .../org/apache/spark/util/SizeEstimator.scala | 33 ++++------ 3 files changed, 83 insertions(+), 43 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/PrimitiveSizes.scala diff --git a/core/src/main/scala/org/apache/spark/util/CastedArray.scala b/core/src/main/scala/org/apache/spark/util/CastedArray.scala index 82435fb784436..8ccc02427984d 100644 --- a/core/src/main/scala/org/apache/spark/util/CastedArray.scala +++ b/core/src/main/scala/org/apache/spark/util/CastedArray.scala @@ -33,34 +33,34 @@ package org.apache.spark.util * is also avoided. It turns out we invoke the get() method to get the value of the array * numerous times, so doing the cast just once is worth the cost of constructing the wrapper * object for larger arrays. + * + * In general, these classes were designed to avoid the need to cast as much as possible. As + * soon as the type of the array is known, it is cast exactly once, and all of its metadata + * (primitive type size, length, and whether or not it is a primitive array) is available + * immediately without any further reflection or introspection on class objects. 
*/ sealed trait CastedArray extends Any { def get(i: Int): AnyRef def getLength(): Int + def isPrimitiveArray(): Boolean + def getElementSize(): Int } object CastedArray { - def castAndWrap(arr: AnyRef): CastedArray = { - if (arr.isInstanceOf[Array[Boolean]]) { - return new BooleanCastedArray(arr.asInstanceOf[Array[Boolean]]) - } else if (arr.isInstanceOf[Array[Byte]]) { - return new ByteCastedArray(arr.asInstanceOf[Array[Byte]]) - } else if (arr.isInstanceOf[Array[Char]]) { - return new CharCastedArray(arr.asInstanceOf[Array[Char]]) - } else if (arr.isInstanceOf[Array[Double]]) { - return new DoubleCastedArray(arr.asInstanceOf[Array[Double]]) - } else if (arr.isInstanceOf[Array[Float]]) { - return new FloatCastedArray(arr.asInstanceOf[Array[Float]]) - } else if (arr.isInstanceOf[Array[Int]]) { - return new IntCastedArray(arr.asInstanceOf[Array[Int]]) - } else if (arr.isInstanceOf[Array[Long]]) { - return new LongCastedArray(arr.asInstanceOf[Array[Long]]) - } else if (arr.isInstanceOf[Array[Object]]) { - return new ObjectCastedArray(arr.asInstanceOf[Array[Object]]) - } else if (arr.isInstanceOf[Array[Short]]) { - return new ShortCastedArray(arr.asInstanceOf[Array[Short]]) - } else { - throw createBadArrayException(arr) + // Sizes of primitive types + + def castAndWrap(obj: AnyRef): CastedArray = { + obj match { + case arr: Array[Boolean] => new BooleanCastedArray(arr) + case arr: Array[Byte] => new ByteCastedArray(arr) + case arr: Array[Char] => new CharCastedArray(arr) + case arr: Array[Double] => new DoubleCastedArray(arr) + case arr: Array[Float] => new FloatCastedArray(arr) + case arr: Array[Int] => new IntCastedArray(arr) + case arr: Array[Long] => new LongCastedArray(arr) + case arr: Array[Object] => new ObjectCastedArray(arr) + case arr: Array[Short] => new ShortCastedArray(arr) + case default => throw createBadArrayException(obj) } } @@ -71,46 +71,65 @@ object CastedArray { private class BooleanCastedArray(val arr: Array[Boolean]) extends AnyVal with 
CastedArray { override def get(i: Int): AnyRef = Boolean.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.BOOLEAN_SIZE } private class ByteCastedArray(val arr: Array[Byte]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Byte.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.BYTE_SIZE } private class CharCastedArray(val arr: Array[Char]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Char.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.CHAR_SIZE } private class DoubleCastedArray(val arr: Array[Double]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Double.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.DOUBLE_SIZE } private class FloatCastedArray(val arr: Array[Float]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Float.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.FLOAT_SIZE } private class IntCastedArray(val arr: Array[Int]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Int.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.INT_SIZE } private class LongCastedArray(val arr: Array[Long]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Long.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = 
PrimitiveSizes.LONG_SIZE } private class ObjectCastedArray(val arr: Array[Object]) extends AnyVal with CastedArray { override def get(i: Int): Object = arr(i) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = false + override def getElementSize(): Int = throw new UnsupportedOperationException("Cannot introspect " + + "the size of an element in an object array.") } private class ShortCastedArray(val arr: Array[Short]) extends AnyVal with CastedArray { override def get(i: Int): AnyRef = Short.box(arr(i)) override def getLength(): Int = arr.length + override def isPrimitiveArray(): Boolean = true + override def getElementSize(): Int = PrimitiveSizes.SHORT_SIZE } private def createBadArrayException(badArray : Object): RuntimeException = { diff --git a/core/src/main/scala/org/apache/spark/util/PrimitiveSizes.scala b/core/src/main/scala/org/apache/spark/util/PrimitiveSizes.scala new file mode 100644 index 0000000000000..7d335af07c090 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/PrimitiveSizes.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Constants for the sizes of primitive types in bytes. 
+ */ +object PrimitiveSizes { + val BYTE_SIZE = 1 + val BOOLEAN_SIZE = 1 + val CHAR_SIZE = 2 + val SHORT_SIZE = 2 + val INT_SIZE = 4 + val LONG_SIZE = 8 + val FLOAT_SIZE = 4 + val DOUBLE_SIZE = 8 +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index e7151a8fd3b0e..f8b21608238cd 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -38,16 +38,6 @@ import org.apache.spark.util.collection.OpenHashSet */ private[spark] object SizeEstimator extends Logging { - // Sizes of primitive types - private val BYTE_SIZE = 1 - private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 - // Alignment boundary for objects // TODO: Is this arch dependent ? private val ALIGN_SIZE = 8 @@ -186,13 +176,12 @@ private[spark] object SizeEstimator extends Logging { private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { val castedArray: CastedArray = CastedArray.castAndWrap(array) val length = castedArray.getLength - val elementClass = cls.getComponentType // Arrays have object header and length field which is an integer - var arrSize: Long = alignSize(objectSize + INT_SIZE) + var arrSize: Long = alignSize(objectSize + PrimitiveSizes.INT_SIZE) - if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + if (castedArray.isPrimitiveArray()) { + arrSize += alignSize(length.toLong * castedArray.getElementSize()) state.size += arrSize } else { arrSize += alignSize(length * pointerSize) @@ -223,21 +212,21 @@ private[spark] object SizeEstimator extends Logging { private def primitiveSize(cls: Class[_]): Long = { if (cls == classOf[Byte]) { - BYTE_SIZE + PrimitiveSizes.BYTE_SIZE } else if (cls == classOf[Boolean]) { - BOOLEAN_SIZE 
+ PrimitiveSizes.BOOLEAN_SIZE } else if (cls == classOf[Char]) { - CHAR_SIZE + PrimitiveSizes.CHAR_SIZE } else if (cls == classOf[Short]) { - SHORT_SIZE + PrimitiveSizes.SHORT_SIZE } else if (cls == classOf[Int]) { - INT_SIZE + PrimitiveSizes.INT_SIZE } else if (cls == classOf[Long]) { - LONG_SIZE + PrimitiveSizes.LONG_SIZE } else if (cls == classOf[Float]) { - FLOAT_SIZE + PrimitiveSizes.FLOAT_SIZE } else if (cls == classOf[Double]) { - DOUBLE_SIZE + PrimitiveSizes.DOUBLE_SIZE } else { throw new IllegalArgumentException( "Non-primitive class " + cls + " passed to primitiveSize()")