Skip to content

Commit

Permalink
Including primitive size information inside CastedArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
mccheah committed Mar 12, 2015
1 parent 93f4b05 commit 6467759
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 43 deletions.
61 changes: 40 additions & 21 deletions core/src/main/scala/org/apache/spark/util/CastedArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,34 @@ package org.apache.spark.util
* is also avoided. It turns out we invoke the get() method to get the value of the array
* numerous times, so doing the cast just once is worth the cost of constructing the wrapper
* object for larger arrays.
*
* In general, these classes are designed to avoid casting as much as possible. As soon as
* the type of the array is known, it is cast exactly once, and all of its metadata
* (primitive element size, length, and whether or not it is a primitive array) is then
* available without any further reflection or introspection on class objects.
*/
sealed trait CastedArray extends Any {
// Returns the element at index `i` as a reference; the primitive-backed implementations in
// this file box the value before returning it.
def get(i: Int): AnyRef
// Number of elements in the wrapped array.
def getLength(): Int
// True when the wrapped array holds primitives, false for the object-array wrapper.
def isPrimitiveArray(): Boolean
// Size in bytes of a single element, taken from PrimitiveSizes by the primitive wrappers.
// The object-array wrapper throws UnsupportedOperationException instead of returning a value.
def getElementSize(): Int
}

object CastedArray {
def castAndWrap(arr: AnyRef): CastedArray = {
if (arr.isInstanceOf[Array[Boolean]]) {
return new BooleanCastedArray(arr.asInstanceOf[Array[Boolean]])
} else if (arr.isInstanceOf[Array[Byte]]) {
return new ByteCastedArray(arr.asInstanceOf[Array[Byte]])
} else if (arr.isInstanceOf[Array[Char]]) {
return new CharCastedArray(arr.asInstanceOf[Array[Char]])
} else if (arr.isInstanceOf[Array[Double]]) {
return new DoubleCastedArray(arr.asInstanceOf[Array[Double]])
} else if (arr.isInstanceOf[Array[Float]]) {
return new FloatCastedArray(arr.asInstanceOf[Array[Float]])
} else if (arr.isInstanceOf[Array[Int]]) {
return new IntCastedArray(arr.asInstanceOf[Array[Int]])
} else if (arr.isInstanceOf[Array[Long]]) {
return new LongCastedArray(arr.asInstanceOf[Array[Long]])
} else if (arr.isInstanceOf[Array[Object]]) {
return new ObjectCastedArray(arr.asInstanceOf[Array[Object]])
} else if (arr.isInstanceOf[Array[Short]]) {
return new ShortCastedArray(arr.asInstanceOf[Array[Short]])
} else {
throw createBadArrayException(arr)
// Sizes of primitive types

/**
 * Wraps `obj` in the [[CastedArray]] implementation that matches its runtime array type,
 * so the cast happens exactly once here rather than on every element access.
 *
 * @param obj the array to wrap
 * @return a wrapper specialised for the element type of `obj`
 * @throws RuntimeException (built by `createBadArrayException`) when `obj` matches none of
 *                          the supported array types
 */
def castAndWrap(obj: AnyRef): CastedArray = obj match {
  case booleans: Array[Boolean] => new BooleanCastedArray(booleans)
  case bytes: Array[Byte] => new ByteCastedArray(bytes)
  case chars: Array[Char] => new CharCastedArray(chars)
  case doubles: Array[Double] => new DoubleCastedArray(doubles)
  case floats: Array[Float] => new FloatCastedArray(floats)
  case ints: Array[Int] => new IntCastedArray(ints)
  case longs: Array[Long] => new LongCastedArray(longs)
  case objects: Array[Object] => new ObjectCastedArray(objects)
  case shorts: Array[Short] => new ShortCastedArray(shorts)
  // Anything else is not an array type this wrapper family supports.
  case _ => throw createBadArrayException(obj)
}

Expand All @@ -71,46 +71,65 @@ object CastedArray {
/** [[CastedArray]] view over an `Array[Boolean]`. */
private class BooleanCastedArray(val arr: Array[Boolean]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.BOOLEAN_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Boolean.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Byte]`. */
private class ByteCastedArray(val arr: Array[Byte]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.BYTE_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Byte.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Char]`. */
private class CharCastedArray(val arr: Array[Char]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.CHAR_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Char.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Double]`. */
private class DoubleCastedArray(val arr: Array[Double]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.DOUBLE_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Double.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Float]`. */
private class FloatCastedArray(val arr: Array[Float]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.FLOAT_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Float.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Int]`. */
private class IntCastedArray(val arr: Array[Int]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.INT_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Int.box(element)
  }
}

/** [[CastedArray]] view over an `Array[Long]`. */
private class LongCastedArray(val arr: Array[Long]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.LONG_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Long.box(element)
  }
}

/**
 * [[CastedArray]] view over an `Array[Object]`.
 *
 * Elements are already references, so `get` returns them without boxing. An object array
 * has no fixed per-element primitive size, so `getElementSize` always throws; callers are
 * expected to check `isPrimitiveArray()` first.
 */
private class ObjectCastedArray(val arr: Array[Object]) extends AnyVal with CastedArray {
  override def get(i: Int): Object = arr(i)

  override def getLength(): Int = arr.length

  override def isPrimitiveArray(): Boolean = false

  /**
   * Not supported for object arrays.
   *
   * @throws UnsupportedOperationException always
   */
  override def getElementSize(): Int =
    // Single literal: the previous concatenation produced a double space in the message.
    throw new UnsupportedOperationException(
      "Cannot introspect the size of an element in an object array.")
}

/** [[CastedArray]] view over an `Array[Short]`. */
private class ShortCastedArray(val arr: Array[Short]) extends AnyVal with CastedArray {
  override def isPrimitiveArray(): Boolean = true

  override def getElementSize(): Int = PrimitiveSizes.SHORT_SIZE

  override def getLength(): Int = arr.length

  // The interface is AnyRef-typed, so the primitive is boxed on the way out.
  override def get(i: Int): AnyRef = {
    val element = arr(i)
    Short.box(element)
  }
}

private def createBadArrayException(badArray : Object): RuntimeException = {
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/scala/org/apache/spark/util/PrimitiveSizes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

/**
 * Sizes, in bytes, used for each primitive type.
 *
 * Constants are grouped by width so that types sharing a size sit next to each other.
 */
object PrimitiveSizes {
  // One-byte types
  val BOOLEAN_SIZE: Int = 1
  val BYTE_SIZE: Int = 1

  // Two-byte types
  val CHAR_SIZE: Int = 2
  val SHORT_SIZE: Int = 2

  // Four-byte types
  val FLOAT_SIZE: Int = 4
  val INT_SIZE: Int = 4

  // Eight-byte types
  val DOUBLE_SIZE: Int = 8
  val LONG_SIZE: Int = 8
}
33 changes: 11 additions & 22 deletions core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ import org.apache.spark.util.collection.OpenHashSet
*/
private[spark] object SizeEstimator extends Logging {

// Sizes of primitive types
private val BYTE_SIZE = 1
private val BOOLEAN_SIZE = 1
private val CHAR_SIZE = 2
private val SHORT_SIZE = 2
private val INT_SIZE = 4
private val LONG_SIZE = 8
private val FLOAT_SIZE = 4
private val DOUBLE_SIZE = 8

// Alignment boundary for objects
// TODO: Is this arch dependent ?
private val ALIGN_SIZE = 8
Expand Down Expand Up @@ -186,13 +176,12 @@ private[spark] object SizeEstimator extends Logging {
private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
val castedArray: CastedArray = CastedArray.castAndWrap(array)
val length = castedArray.getLength
val elementClass = cls.getComponentType

// Arrays have object header and length field which is an integer
var arrSize: Long = alignSize(objectSize + INT_SIZE)
var arrSize: Long = alignSize(objectSize + PrimitiveSizes.INT_SIZE)

if (elementClass.isPrimitive) {
arrSize += alignSize(length * primitiveSize(elementClass))
if (castedArray.isPrimitiveArray()) {
arrSize += alignSize(length * castedArray.getElementSize())
state.size += arrSize
} else {
arrSize += alignSize(length * pointerSize)
Expand Down Expand Up @@ -223,21 +212,21 @@ private[spark] object SizeEstimator extends Logging {

private def primitiveSize(cls: Class[_]): Long = {
if (cls == classOf[Byte]) {
BYTE_SIZE
PrimitiveSizes.BYTE_SIZE
} else if (cls == classOf[Boolean]) {
BOOLEAN_SIZE
PrimitiveSizes.BOOLEAN_SIZE
} else if (cls == classOf[Char]) {
CHAR_SIZE
PrimitiveSizes.CHAR_SIZE
} else if (cls == classOf[Short]) {
SHORT_SIZE
PrimitiveSizes.SHORT_SIZE
} else if (cls == classOf[Int]) {
INT_SIZE
PrimitiveSizes.INT_SIZE
} else if (cls == classOf[Long]) {
LONG_SIZE
PrimitiveSizes.LONG_SIZE
} else if (cls == classOf[Float]) {
FLOAT_SIZE
PrimitiveSizes.FLOAT_SIZE
} else if (cls == classOf[Double]) {
DOUBLE_SIZE
PrimitiveSizes.DOUBLE_SIZE
} else {
throw new IllegalArgumentException(
"Non-primitive class " + cls + " passed to primitiveSize()")
Expand Down

0 comments on commit 6467759

Please sign in to comment.