From 9be5769a68d7cc9dced732649adc4c977b8eac8d Mon Sep 17 00:00:00 2001 From: Jan Lahoda Date: Mon, 5 Jun 2023 10:48:25 +0000 Subject: [PATCH] 8291966: SwitchBootstrap.typeSwitch could be faster Reviewed-by: asotona --- .../java/lang/runtime/SwitchBootstraps.java | 262 ++++++++++++------ .../lang/runtime/SwitchBootstrapsTest.java | 7 + 2 files changed, 182 insertions(+), 87 deletions(-) diff --git a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java index 4ac90d355035c..4743b00997030 100644 --- a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java +++ b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,16 +27,17 @@ import java.lang.Enum.EnumDesc; import java.lang.invoke.CallSite; -import java.lang.invoke.ConstantBootstraps; import java.lang.invoke.ConstantCallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.util.List; import java.util.Objects; import java.util.stream.Stream; +import jdk.internal.access.SharedSecrets; +import jdk.internal.vm.annotation.Stable; import static java.util.Objects.requireNonNull; -import jdk.internal.vm.annotation.Stable; /** * Bootstrap methods for linking {@code invokedynamic} call sites that implement @@ -50,18 +51,38 @@ public class SwitchBootstraps { private SwitchBootstraps() {} + private static final Object SENTINEL = new Object(); private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); - private static final MethodHandle DO_TYPE_SWITCH; - private static final MethodHandle DO_ENUM_SWITCH; + private static final MethodHandle INSTANCEOF_CHECK; + private static final MethodHandle INTEGER_EQ_CHECK; + private static final MethodHandle OBJECT_EQ_CHECK; + private static final MethodHandle ENUM_EQ_CHECK; + private static final MethodHandle NULL_CHECK; + private static final MethodHandle IS_ZERO; + private static final MethodHandle CHECK_INDEX; + private static final MethodHandle MAPPED_ENUM_LOOKUP; static { try { - DO_TYPE_SWITCH = LOOKUP.findStatic(SwitchBootstraps.class, "doTypeSwitch", - MethodType.methodType(int.class, Object.class, int.class, Object[].class)); - DO_ENUM_SWITCH = LOOKUP.findStatic(SwitchBootstraps.class, "doEnumSwitch", - MethodType.methodType(int.class, Enum.class, int.class, Object[].class, - MethodHandles.Lookup.class, Class.class, ResolvedEnumLabels.class)); + INSTANCEOF_CHECK = MethodHandles.permuteArguments(LOOKUP.findVirtual(Class.class, "isInstance", + MethodType.methodType(boolean.class, Object.class)), + MethodType.methodType(boolean.class, Object.class, Class.class), 1, 0); + INTEGER_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "integerEqCheck", + MethodType.methodType(boolean.class, Object.class, Integer.class)); + OBJECT_EQ_CHECK = LOOKUP.findStatic(Objects.class, "equals", + MethodType.methodType(boolean.class, Object.class, Object.class)); + ENUM_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "enumEqCheck", + MethodType.methodType(boolean.class, Object.class, EnumDesc.class, MethodHandles.Lookup.class, ResolvedEnumLabel.class)); + NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull", + MethodType.methodType(boolean.class, Object.class)); + IS_ZERO = LOOKUP.findStatic(SwitchBootstraps.class, "isZero", + MethodType.methodType(boolean.class, int.class)); + CHECK_INDEX = LOOKUP.findStatic(Objects.class, "checkIndex", + MethodType.methodType(int.class, int.class, int.class)); + MAPPED_ENUM_LOOKUP = LOOKUP.findStatic(SwitchBootstraps.class, "mappedEnumLookup", + MethodType.methodType(int.class, Enum.class, MethodHandles.Lookup.class, + Class.class, EnumDesc[].class, EnumMap.class)); } catch (ReflectiveOperationException e) { throw new ExceptionInInitializerError(e); @@ -134,7 +155,8 @@ public static CallSite typeSwitch(MethodHandles.Lookup lookup, labels = labels.clone(); Stream.of(labels).forEach(SwitchBootstraps::verifyLabel); - MethodHandle target = MethodHandles.insertArguments(DO_TYPE_SWITCH, 2, (Object) labels); + MethodHandle target = createMethodHandleSwitch(lookup, labels); + return new ConstantCallSite(target); } @@ -151,36 +173,81 @@ private static void verifyLabel(Object label) { } } - private static int doTypeSwitch(Object target, int startIndex, Object[] labels) { - Objects.checkIndex(startIndex, labels.length + 1); - - if (target == null) - return -1; - - // Dumbest possible strategy - Class targetClass = target.getClass(); - for (int i = startIndex; i < labels.length; i++) { - Object label = labels[i]; - if (label instanceof Class c) { - if (c.isAssignableFrom(targetClass)) - return i; - } else if (label instanceof Integer constant) { - if (target instanceof Number input && constant.intValue() == input.intValue()) { - return i; - } else if (target instanceof Character input && constant.intValue() == input.charValue()) { - return i; - } - } else if (label instanceof EnumDesc enumDesc) { - if (target.getClass().isEnum() && - ((Enum) target).describeConstable().stream().anyMatch(d -> d.equals(enumDesc))) { - return i; + /* + * Construct test chains for labels inside switch, to handle switch repeats: + * switch (idx) { + * case 0 -> if (selector matches label[0]) return 0; else if (selector matches label[1]) return 1; else ... + * case 1 -> if (selector matches label[1]) return 1; else ... + * ... + * } + */ + private static MethodHandle createRepeatIndexSwitch(MethodHandles.Lookup lookup, Object[] labels) { + MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); + MethodHandle[] testChains = new MethodHandle[labels.length]; + List labelsList = List.of(labels).reversed(); + + for (int i = 0; i < labels.length; i++) { + MethodHandle test = def; + int idx = labels.length - 1; + List currentLabels = labelsList.subList(0, labels.length - i); + + for (int j = 0; j < currentLabels.size(); j++, idx--) { + Object currentLabel = currentLabels.get(j); + if (j + 1 < currentLabels.size() && currentLabels.get(j + 1) == currentLabel) continue; + MethodHandle currentTest; + if (currentLabel instanceof Class) { + currentTest = INSTANCEOF_CHECK; + } else if (currentLabel instanceof Integer) { + currentTest = INTEGER_EQ_CHECK; + } else if (currentLabel instanceof EnumDesc) { + currentTest = MethodHandles.insertArguments(ENUM_EQ_CHECK, 2, lookup, new ResolvedEnumLabel()); + } else { + currentTest = OBJECT_EQ_CHECK; } - } else if (label.equals(target)) { - return i; + test = MethodHandles.guardWithTest(MethodHandles.insertArguments(currentTest, 1, currentLabel), + MethodHandles.dropArguments(MethodHandles.constant(int.class, idx), 0, Object.class), + test); } + testChains[i] = MethodHandles.dropArguments(test, 0, int.class); + } + + return MethodHandles.tableSwitch(MethodHandles.dropArguments(def, 0, int.class), testChains); + } + + /* + * Construct code that maps the given selector and repeat index to a case label number: + * if (selector == null) return -1; + * else return "createRepeatIndexSwitch(labels)" + */ + private static MethodHandle createMethodHandleSwitch(MethodHandles.Lookup lookup, Object[] labels) { + MethodHandle mainTest; + MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); + if (labels.length > 0) { + mainTest = createRepeatIndexSwitch(lookup, labels); + } else { + mainTest = MethodHandles.dropArguments(def, 0, int.class); + } + MethodHandle body = + MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), + MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), + mainTest); + MethodHandle switchImpl = + MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); + return withIndexCheck(switchImpl, labels.length); + } + + private static boolean integerEqCheck(Object value, Integer constant) { + if (value instanceof Number input && constant.intValue() == input.intValue()) { + return true; + } else if (value instanceof Character input && constant.intValue() == input.charValue()) { + return true; } - return labels.length; + return false; + } + + private static boolean isZero(int value) { + return value == 0; } /** @@ -254,28 +321,31 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup, labels = labels.clone(); Class enumClass = invocationType.parameterType(0); - Stream.of(labels).forEach(l -> validateEnumLabel(enumClass, l)); - MethodHandle temporary = - MethodHandles.insertArguments(DO_ENUM_SWITCH, 2, labels, lookup, enumClass, new ResolvedEnumLabels()); - temporary = temporary.asType(invocationType); + labels = Stream.of(labels).map(l -> convertEnumConstants(lookup, enumClass, l)).toArray(); - return new ConstantCallSite(temporary); - } + MethodHandle target; + boolean constantsOnly = Stream.of(labels).allMatch(l -> enumClass.isAssignableFrom(EnumDesc.class)); - private static > void validateEnumLabel(Class enumClassTemplate, Object label) { - if (label == null) { - throw new IllegalArgumentException("null label found"); - } - Class labelClass = label.getClass(); - if (labelClass == Class.class) { - if (label != enumClassTemplate) { - throw new IllegalArgumentException("the Class label: " + label + - ", expected the provided enum class: " + enumClassTemplate); - } - } else if (labelClass != String.class) { - throw new IllegalArgumentException("label with illegal type found: " + labelClass + - ", expected label of type either String or Class"); + if (labels.length > 0 && constantsOnly) { + //If all labels are enum constants, construct an optimized handle for repeat index 0: + //if (selector == null) return -1 + //else if (idx == 0) return mappingArray[selector.ordinal()]; //mapping array created lazily + //else return "createRepeatIndexSwitch(labels)" + MethodHandle body = + MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), + MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), + MethodHandles.guardWithTest(MethodHandles.dropArguments(IS_ZERO, 1, Object.class), + createRepeatIndexSwitch(lookup, labels), + MethodHandles.insertArguments(MAPPED_ENUM_LOOKUP, 1, lookup, enumClass, labels, new EnumMap()))); + target = MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); + } else { + target = createMethodHandleSwitch(lookup, labels); } + + target = target.asType(invocationType); + target = withIndexCheck(target, labels.length); + + return new ConstantCallSite(target); } private static > Object convertEnumConstants(MethodHandles.Lookup lookup, Class enumClassTemplate, Object label) { @@ -290,52 +360,70 @@ private static > Object convertEnumConstants(MethodHandles.Loo } return label; } else if (labelClass == String.class) { - @SuppressWarnings("unchecked") - Class enumClass = (Class) enumClassTemplate; - try { - return ConstantBootstraps.enumConstant(lookup, (String) label, enumClass); - } catch (IllegalArgumentException ex) { - return null; - } + return EnumDesc.of(enumClassTemplate.describeConstable().get(), (String) label); } else { throw new IllegalArgumentException("label with illegal type found: " + labelClass + ", expected label of type either String or Class"); } } - private static int doEnumSwitch(Enum target, int startIndex, Object[] unresolvedLabels, - MethodHandles.Lookup lookup, Class enumClass, - ResolvedEnumLabels resolvedLabels) { - Objects.checkIndex(startIndex, unresolvedLabels.length + 1); + private static > int mappedEnumLookup(T value, MethodHandles.Lookup lookup, Class enumClass, EnumDesc[] labels, EnumMap enumMap) { + if (enumMap.map == null) { + T[] constants = SharedSecrets.getJavaLangAccess().getEnumConstantsShared(enumClass); + int[] map = new int[constants.length]; + int ordinal = 0; - if (target == null) - return -1; + for (T constant : constants) { + map[ordinal] = labels.length; - if (resolvedLabels.resolvedLabels == null) { - resolvedLabels.resolvedLabels = Stream.of(unresolvedLabels) - .map(l -> convertEnumConstants(lookup, enumClass, l)) - .toArray(); + for (int i = 0; i < labels.length; i++) { + if (Objects.equals(labels[i].constantName(), constant.name())) { + map[ordinal] = i; + break; + } + } + + ordinal++; + } } + return enumMap.map[value.ordinal()]; + } + + private static boolean enumEqCheck(Object value, EnumDesc label, MethodHandles.Lookup lookup, ResolvedEnumLabel resolvedEnum) { + if (resolvedEnum.resolvedEnum == null) { + Object resolved; + + try { + Class clazz = label.constantType().resolveConstantDesc(lookup); - Object[] labels = resolvedLabels.resolvedLabels; - - // Dumbest possible strategy - Class targetClass = target.getClass(); - for (int i = startIndex; i < labels.length; i++) { - Object label = labels[i]; - if (label instanceof Class c) { - if (c.isAssignableFrom(targetClass)) - return i; - } else if (label == target) { - return i; + if (value.getClass() != clazz) { + return false; + } + + resolved = label.resolveConstantDesc(lookup); + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + resolved = SENTINEL; } + + resolvedEnum.resolvedEnum = resolved; } - return labels.length; + return value == resolvedEnum.resolvedEnum; + } + + private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) { + MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1); + + return MethodHandles.filterArguments(target, 1, checkIndex); + } + + private static final class ResolvedEnumLabel { + @Stable + public Object resolvedEnum; } - private static final class ResolvedEnumLabels { + private static final class EnumMap { @Stable - public Object[] resolvedLabels; + public int[] map; } } diff --git a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java index f69a9b78cefd4..2820ad5a01150 100644 --- a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java +++ b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java @@ -113,6 +113,7 @@ public void testTypes() throws Throwable { testType("", 0, 0, String.class, String.class, String.class); testType("", 1, 1, String.class, String.class, String.class); testType("", 2, 2, String.class, String.class, String.class); + testType("", 0, 0); } public void testEnums() throws Throwable { @@ -131,6 +132,12 @@ public void testEnums() throws Throwable { } catch (IllegalArgumentException ex) { //OK } + testEnum(E1.B, 0, 0, "B", "A"); + testEnum(E1.A, 0, 1, "B", "A"); + testEnum(E1.A, 0, 0, "A", "A", "B"); + testEnum(E1.A, 1, 1, "A", "A", "B"); + testEnum(E1.A, 2, 3, "A", "A", "B"); + testEnum(E1.A, 0, 0); } public void testWrongSwitchTypes() throws Throwable {