Skip to content

Commit

Permalink
Replace boundary on StringByteIndexPrimitiveNode with faster specia…
Browse files Browse the repository at this point in the history
…lizations for boundary checks and single-byte-optimizable strings.
  • Loading branch information
nirvdrum committed Jun 18, 2021
1 parent 58d10fb commit 1fcbd1d
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 21 deletions.
47 changes: 47 additions & 0 deletions src/main/java/org/truffleruby/core/cast/ToRopeNode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2021 Oracle and/or its affiliates. All rights reserved. This
* code is released under a tri EPL/GPL/LGPL license. You can use it,
* redistribute it and/or modify it under the terms of the:
*
* Eclipse Public License version 2.0, or
* GNU General Public License version 2, or
* GNU Lesser General Public License version 2.1.
*/

package org.truffleruby.core.cast;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import org.jcodings.Encoding;
import org.truffleruby.core.rope.Rope;
import org.truffleruby.core.string.ImmutableRubyString;
import org.truffleruby.core.string.RubyString;
import org.truffleruby.language.RubyContextSourceNode;
import org.truffleruby.language.RubyNode;

@NodeChild(value = "child", type = RubyNode.class)
public abstract class ToRopeNode extends RubyContextSourceNode {

public abstract Rope executeToRope(Object object);

public static ToRopeNode create() {
return ToRopeNodeGen.create(null);
}

@Specialization
protected Rope coerceRubyString(RubyString string) {
return string.rope;
}

@Specialization
protected Rope coerceImmutableRubyString(ImmutableRubyString string) {
return string.rope;
}

@Fallback
protected Encoding failure(Object value) {
throw CompilerDirectives.shouldNotReachHere();
}
}
96 changes: 75 additions & 21 deletions src/main/java/org/truffleruby/core/string/StringNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
import java.nio.charset.StandardCharsets;

import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.nodes.LoopNode;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import org.jcodings.Config;
import org.jcodings.Encoding;
import org.jcodings.specific.ASCIIEncoding;
Expand All @@ -100,6 +102,7 @@
import org.truffleruby.core.cast.BooleanCastNode;
import org.truffleruby.core.cast.ToIntNode;
import org.truffleruby.core.cast.ToLongNode;
import org.truffleruby.core.cast.ToRopeNodeGen;
import org.truffleruby.core.cast.ToStrNode;
import org.truffleruby.core.cast.ToStrNodeGen;
import org.truffleruby.core.encoding.EncodingLeftCharHeadNode;
Expand Down Expand Up @@ -4461,42 +4464,93 @@ protected Object stringCharacterIndex(Object string, Object pattern, int offset,
}

@Primitive(name = "string_byte_index", lowerFixnum = 2)
public abstract static class StringByteIndexPrimitiveNode extends PrimitiveArrayArgumentsNode {
@NodeChild(value = "string", type = RubyNode.class)
@NodeChild(value = "pattern", type = RubyNode.class)
@NodeChild(value = "offset", type = RubyNode.class)
public abstract static class StringByteIndexPrimitiveNode extends PrimitiveNode {

@TruffleBoundary
@Specialization
protected Object stringCharacterIndex(Object string, Object pattern, int offset,
@Cached RopeNodes.CalculateCharacterLengthNode calculateCharacterLengthNode,
@CachedLibrary(limit = "2") RubyStringLibrary libString,
@CachedLibrary(limit = "2") RubyStringLibrary libPattern) {
if (offset < 0) {
return nil;
}
@Child SingleByteOptimizableNode singleByteOptimizableNode = SingleByteOptimizableNode.create();

final Rope stringRope = libString.getRope(string);
final Rope patternRope = libPattern.getRope(pattern);
@CreateCast("string")
protected RubyNode coerceStringToRope(RubyNode string) {
return ToRopeNodeGen.create(string);
}

final int total = stringRope.byteLength();
int p = 0;
final int e = p + total;
@CreateCast("pattern")
protected RubyNode coercePatternToRope(RubyNode pattern) {
return ToRopeNodeGen.create(pattern);
}

@Specialization(guards = "offset < 0")
protected Object stringByteIndexNegativeOffset(Rope stringRope, Rope patternRope, int offset) {
return nil;
}

@Specialization(
guards = {
"offset >= 0",
"singleByteOptimizableNode.execute(stringRope)",
"patternRope.byteLength() > stringRope.byteLength()" })
protected Object stringByteIndexPatternTooLarge(Rope stringRope, Rope patternRope, int offset) {
return nil;
}

@Specialization(
guards = {
"offset >= 0",
"singleByteOptimizableNode.execute(stringRope)",
"patternRope.byteLength() <= stringRope.byteLength()" })
protected Object stringCharacterIndexSingleByteOptimizable(Rope stringRope, Rope patternRope, int offset,
@Cached BranchProfile matchProfile,
@Cached BranchProfile noMatchProfile,
@Cached RopeNodes.BytesNode stringBytesNode,
@Cached RopeNodes.BytesNode patternBytesNode,
@Cached LoopConditionProfile loopProfile) {

int p = offset;
final int e = stringRope.byteLength();
final int pe = patternRope.byteLength();
final int l = e - pe + 1;

final byte[] stringBytes = stringRope.getBytes();
final byte[] patternBytes = patternRope.getBytes();
final byte[] stringBytes = stringBytesNode.execute(stringRope);
final byte[] patternBytes = patternBytesNode.execute(patternRope);

p += offset;
try {
loopProfile.profileCounted(l - p);

if (stringRope.isSingleByteOptimizable()) {
for (; p < l; p++) {
if (ArrayUtils.memcmp(stringBytes, p, patternBytes, 0, pe) == 0) {
matchProfile.enter();
return p;
}
}

return nil;
} finally {
LoopNode.reportLoopCount(this, p - offset);
}

noMatchProfile.enter();
return nil;
}

@TruffleBoundary
@Specialization(
guards = {
"offset >= 0",
"!singleByteOptimizableNode.execute(stringRope)",
"patternRope.byteLength() <= stringRope.byteLength()" })
protected Object stringCharacterIndex(Rope stringRope, Rope patternRope, int offset,
@Cached RopeNodes.CalculateCharacterLengthNode calculateCharacterLengthNode,
@Cached RopeNodes.BytesNode stringBytesNode,
@Cached RopeNodes.BytesNode patternBytesNode) {

int p = offset;
final int e = stringRope.byteLength();
final int pe = patternRope.byteLength();
final int l = e - pe + 1;

final byte[] stringBytes = stringBytesNode.execute(stringRope);
final byte[] patternBytes = patternBytesNode.execute(patternRope);

final Encoding enc = stringRope.getEncoding();
final CodeRange cr = stringRope.getCodeRange();
int c = 0;
Expand Down

0 comments on commit 1fcbd1d

Please sign in to comment.