From 1fe09de24d48cdfbff995a7ef1f82d72ff521509 Mon Sep 17 00:00:00 2001 From: Justin Uang Date: Thu, 1 Jan 2015 12:27:11 -0800 Subject: [PATCH] [SPARK-6269] Use a different implementation of java.lang.reflect.Array This patch uses a different implementation of java.lang.reflect.Array. The code is copied and pasted from https://bugs.openjdk.java.net/browse/JDK-8051447. The appropriate code is in the public domain, so we can use it. The idea is to use pure Java code in implementing the methods there, as opposed to relying on native C code which ends up being ill-performing. This greatly improves the performance of estimating the size of arrays when we are checking for spilling in Spark. Benchmark results are in the bug description and in the discussion on the linked pull request. --- .../org/apache/spark/util/ArrayReflect.java | 359 ++++++++++++++++++ .../org/apache/spark/util/SizeEstimator.scala | 2 +- 2 files changed, 360 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ArrayReflect.java diff --git a/core/src/main/scala/org/apache/spark/util/ArrayReflect.java b/core/src/main/scala/org/apache/spark/util/ArrayReflect.java new file mode 100644 index 0000000000000..bbf69557a9367 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ArrayReflect.java @@ -0,0 +1,359 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.util; + +/** + * Alternative implementation of {@link java.lang.reflect.Array} that is more performant + * over calling to native C in many cases. Based on code from + * an open JDK ticket. + */ +public class ArrayReflect { + /* + * Legal conversions: + * From \ To| boolean | byte | char | short | int | long | float | double + * ----------|------------------------------------------------------------- + * boolean | # + * byte | # # # # # # + * char | # # # # # + * short | # # # # # + * int | # # # # + * long | # # # + * float | # # + * double | # + */ + + private static RuntimeException badArray(Object array) { + if (array == null) { + return new NullPointerException("Array argument is null"); + } else if (!array.getClass().isArray()) { + return new IllegalArgumentException("Argument is not an array"); + } else { + return new IllegalArgumentException("Array is of incompatible type"); + } + } + + public static int getLength(Object array) { + // Note: all types of multidimensional arrays are instanceof Object[] + if (array instanceof Object[]) { + return ((Object[])array).length; + } + if (array instanceof boolean[]) { + return ((boolean[]) array).length; + } + if (array instanceof byte[]) { + return ((byte[])array).length; + } + if (array instanceof char[]) { + return ((char[])array).length; + } + if (array instanceof short[]) { + return ((short[])array).length; + } + if (array instanceof int[]) { + return ((int[])array).length; + } + if (array instanceof long[]) { + return ((long[])array).length; + } + if (array instanceof float[]) { + return ((float[])array).length; + } + if (array instanceof double[]) { + return ((double[])array).length; + } + throw badArray(array); + } + + public static boolean getBoolean(Object array, int index) { + if (array instanceof boolean[]) { + return ((boolean[])array)[index]; + } + throw badArray(array); + } + + public static byte getByte(Object array, int index) { + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + throw badArray(array); + } + + public static char getChar(Object array, int index) { + if (array instanceof char[]) { + return ((char[])array)[index]; + } + throw badArray(array); + } + + public static short getShort(Object array, int index) { + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + throw badArray(array); + } + + public static int getInt(Object array, int index) { + if (array instanceof int[]) { + return ((int[]) array)[index]; + } + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof char[]) { + return ((char[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + throw badArray(array); + } + + public static long getLong(Object array, int index) { + if (array instanceof long[]) { + return ((long[])array)[index]; + } + if (array instanceof int[]) { + return ((int[])array)[index]; + } + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof char[]) { + return ((char[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + throw badArray(array); + } + + public static float getFloat(Object array, int index) { + if (array instanceof float[]) { + return ((float[])array)[index]; + } + if (array instanceof long[]) { + return ((long[])array)[index]; + } + if (array instanceof int[]) { + return ((int[])array)[index]; + } + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof char[]) { + return ((char[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[]) array)[index]; + } + throw badArray(array); + } + + public static double getDouble(Object array, int index) { + if (array instanceof double[]) { + return ((double[])array)[index]; + } + if (array instanceof float[]) { + return ((float[])array)[index]; + } + if (array instanceof long[]) { + return ((long[])array)[index]; + } + if (array instanceof int[]) { + return ((int[])array)[index]; + } + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof char[]) { + return ((char[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + throw badArray(array); + } + + public static Object get(Object array, int index) { + if (array instanceof Object[]) { + return ((Object[])array)[index]; + } + if (array instanceof boolean[]) { + return ((boolean[])array)[index]; + } + if (array instanceof byte[]) { + return ((byte[])array)[index]; + } + if (array instanceof char[]) { + return ((char[])array)[index]; + } + if (array instanceof short[]) { + return ((short[])array)[index]; + } + if (array instanceof int[]) { + return ((int[])array)[index]; + } + if (array instanceof long[]) { + return ((long[])array)[index]; + } + if (array instanceof float[]) { + return ((float[])array)[index]; + } + if (array instanceof double[]) { + return ((double[])array)[index]; + } + throw badArray(array); + } + + public static void setBoolean(Object array, int index, boolean z) { + if (array instanceof boolean[]) { + ((boolean[])array)[index] = z; + } else { + throw badArray(array); + } + } + + public static void setByte(Object array, int index, byte b) { + if (array instanceof byte[]) { + ((byte[])array)[index] = b; + } else if (array instanceof short[]) { + ((short[])array)[index] = b; + } else if (array instanceof int[]) { + ((int[])array)[index] = b; + } else if (array instanceof long[]) { + ((long[])array)[index] = b; + } else if (array instanceof float[]) { + ((float[])array)[index] = b; + } else if (array instanceof double[]) { + ((double[])array)[index] = b; + } else { + throw badArray(array); + } + } + + public static void setChar(Object array, int index, char c) { + if (array instanceof char[]) { + ((char[])array)[index] = c; + } else if (array instanceof int[]) { + ((int[])array)[index] = c; + } else if (array instanceof long[]) { + ((long[])array)[index] = c; + } else if (array instanceof float[]) { + ((float[])array)[index] = c; + } else if (array instanceof double[]) { + ((double[])array)[index] = c; + } else { + throw badArray(array); + } + } + + public static void setShort(Object array, int index, short s) { + if (array instanceof short[]) { + ((short[])array)[index] = s; + } else if (array instanceof int[]) { + ((int[])array)[index] = s; + } else if (array instanceof long[]) { + ((long[])array)[index] = s; + } else if (array instanceof float[]) { + ((float[])array)[index] = s; + } else if (array instanceof double[]) { + ((double[])array)[index] = s; + } else { + throw badArray(array); + } + } + + public static void setInt(Object array, int index, int i) { + if (array instanceof int[]) { + ((int[])array)[index] = i; + } else if (array instanceof long[]) { + ((long[])array)[index] = i; + } else if (array instanceof float[]) { + ((float[])array)[index] = i; + } else if (array instanceof double[]) { + ((double[])array)[index] = i; + } else { + throw badArray(array); + } + } + + public static void setLong(Object array, int index, long l) { + if (array instanceof long[]) { + ((long[])array)[index] = l; + } else if (array instanceof float[]) { + ((float[])array)[index] = l; + } else if (array instanceof double[]) { + ((double[])array)[index] = l; + } else { + throw badArray(array); + } + } + + public static void setFloat(Object array, int index, float f) { + if (array instanceof float[]) { + ((float[])array)[index] = f; + } else { + if (array instanceof double[]) { + ((double[])array)[index] = f; + } else { + throw badArray(array); + } + } + + + public static void setDouble(Object array, int index, double d) { + if (array instanceof double[]) { + ((double[])array)[index] = d; + } else { + throw badArray(array); + } + } + + public static void set(Object array, int index, Object value) { + if (array instanceof Object[]) { + try { + ((Object[])array)[index] = value; + } catch (ArrayStoreException e) { + throw badArray(array); + } + } else { + if (value instanceof Boolean) { + setBoolean(array, index, ((Boolean) value).booleanValue()); + } else if (value instanceof Byte) { + setByte(array, index, ((Byte) value).byteValue()); + } else if (value instanceof Short) { + setShort(array, index, ((Short) value).shortValue()); + } else if (value instanceof Character) { + setChar(array, index, ((Character) value).charValue()); + } else if (value instanceof Integer) { + setInt(array, index, ((Integer) value).intValue()); + } else if (value instanceof Long) { + setLong(array, index, ((Long) value).longValue()); + } else if (value instanceof Float) { + setFloat(array, index, ((Float) value).floatValue()); + } else if (value instanceof Double) { + setDouble(array, index, ((Double) value).doubleValue()); + } else { + throw badArray(array); + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index bce3b3afe9aba..53386c9a9a4e0 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.lang.management.ManagementFactory -import java.lang.reflect.{Array => JArray} +import org.apache.spark.util.{ArrayReflect => JArray} import java.lang.reflect.Field import java.lang.reflect.Modifier import java.util.IdentityHashMap