Skip to content

Commit

Permalink
[GR-19220] TRegex splitting (#2389)
Browse files Browse the repository at this point in the history
PullRequest: truffleruby/2767
  • Loading branch information
eregon committed Jun 26, 2021
2 parents 9752324 + 405a048 commit 6eaa842
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 410 deletions.
2 changes: 1 addition & 1 deletion lib/truffle/strscan.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_byte
end

# We need to match one byte, regardless of the string encoding
@match = Primitive.matchdata_create(/./mn, @string, [pos], [pos+1])
@match = Primitive.matchdata_create_single_group(/./mn, @string, pos, pos+1)

@prev_pos = @pos
@pos += 1
Expand Down
8 changes: 4 additions & 4 deletions spec/truffle/caller_data_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def self.caller_binding_and_variables(last_line, last_match)

it "can have its special variables read and modified" do
last_line = "Hello!"
md = Primitive.matchdata_create(/o/, "Hello", [4], [5])
md = Primitive.matchdata_create_single_group(/o/, "Hello", 4, 5)
TruffleCallerSpecFixtures.last_line_set(last_line)
TruffleCallerSpecFixtures.last_match_set(md)
$_.should == last_line
Expand All @@ -48,7 +48,7 @@ def self.caller_binding_and_variables(last_line, last_match)

it "can have its special variables read and modified through an intermediate #send" do
last_line = "Hello!"
md = Primitive.matchdata_create(/o/, "Hello", [4], [5])
md = Primitive.matchdata_create_single_group(/o/, "Hello", 4, 5)
TruffleCallerSpecFixtures.send(:last_line_set, last_line)
TruffleCallerSpecFixtures.send(:last_match_set, md)
$_.should == last_line
Expand All @@ -57,7 +57,7 @@ def self.caller_binding_and_variables(last_line, last_match)

it "can have its special variables and frame read by the same method" do
last_line = "Hello!"
md = Primitive.matchdata_create(/o/, "Hello", [4], [5])
md = Primitive.matchdata_create_single_group(/o/, "Hello", 4, 5)
b = TruffleCallerSpecFixtures.caller_binding_and_variables(last_line, md)
$_.should == last_line
$~.should == md
Expand All @@ -66,7 +66,7 @@ def self.caller_binding_and_variables(last_line, last_match)

it "can have its special variables and frame read by the same method through an intermediate #send" do
last_line = "Hello!"
md = Primitive.matchdata_create(/o/, "Hello", [4], [5])
md = Primitive.matchdata_create_single_group(/o/, "Hello", 4, 5)
b = TruffleCallerSpecFixtures.send(:caller_binding_and_variables, last_line, md)
$_.should == last_line
$~.should == md
Expand Down
28 changes: 0 additions & 28 deletions src/main/java/org/truffleruby/core/regexp/MatchDataNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
import org.truffleruby.builtins.Primitive;
import org.truffleruby.builtins.PrimitiveArrayArgumentsNode;
import org.truffleruby.builtins.UnaryCoreMethodNode;
import org.truffleruby.core.array.ArrayIndexNodes;
import org.truffleruby.core.array.ArrayOperations;
import org.truffleruby.core.array.ArrayUtils;
import org.truffleruby.core.array.RubyArray;
import org.truffleruby.core.cast.IntegerCastNode;
import org.truffleruby.core.cast.ToIntNode;
import org.truffleruby.core.klass.RubyClass;
import org.truffleruby.core.range.RubyIntRange;
Expand Down Expand Up @@ -164,32 +162,6 @@ protected Object create(Object regexp, Object string, int start, int end) {

}

/** Primitive {@code matchdata_create}: assembles a {@link RubyMatchData} from a regexp, a string and two parallel
 * arrays holding the start and end offset of every group. */
@Primitive(name = "matchdata_create")
public abstract static class MatchDataCreateNode extends PrimitiveArrayArgumentsNode {

    /** Fills a {@link Region} sized after {@code starts} with the int-cast elements of {@code starts} and
     * {@code ends}, then wraps it in a newly allocated (and allocation-traced) match data object.
     *
     * @param regexp the regexp stored in the match data
     * @param string the string stored in the match data
     * @param starts per-group start offsets; its size determines the number of groups
     * @param ends per-group end offsets, read at the same indices as {@code starts} */
    @Specialization
    protected Object create(Object regexp, Object string, RubyArray starts, RubyArray ends,
            @Cached ArrayIndexNodes.ReadNormalizedNode readNode,
            @Cached IntegerCastNode integerCastNode) {
        final Region offsets = new Region(starts.size);

        for (int group = 0; group < offsets.numRegs; group++) {
            final Object startValue = readNode.executeRead(starts, group);
            offsets.beg[group] = integerCastNode.executeCastInt(startValue);

            final Object endValue = readNode.executeRead(ends, group);
            offsets.end[group] = integerCastNode.executeCastInt(endValue);
        }

        final RubyMatchData result = new RubyMatchData(
                coreLibrary().matchDataClass,
                getLanguage().matchDataShape,
                regexp,
                string,
                offsets,
                null);
        AllocationTracing.trace(result, this);
        return result;
    }

}

@CoreMethod(
names = "[]",
required = 1,
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/org/truffleruby/core/regexp/TRegexCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.jcodings.specific.USASCIIEncoding;
import org.jcodings.specific.UTF8Encoding;
import org.truffleruby.RubyContext;
import org.truffleruby.core.encoding.RubyEncoding;
import org.truffleruby.core.rope.CannotConvertBinaryRubyStringToJavaString;
import org.truffleruby.core.rope.Rope;
import org.truffleruby.core.rope.RopeBuilder;
Expand Down Expand Up @@ -60,8 +59,7 @@ public Object getBinaryRegex(boolean atStart) {
}

@TruffleBoundary
public Object compile(RubyContext context, RubyRegexp regexp, boolean atStart, RubyEncoding rubyEncoding) {
final Encoding encoding = rubyEncoding.encoding;
public Object compile(RubyContext context, RubyRegexp regexp, boolean atStart, Encoding encoding) {
final Object tregex = compileTRegex(context, regexp, atStart, encoding);
if (tregex == null) {
return Nil.INSTANCE;
Expand Down
200 changes: 170 additions & 30 deletions src/main/java/org/truffleruby/core/regexp/TruffleRegexpNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.LoopNode;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.IntValueProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import org.jcodings.Encoding;
import org.jcodings.specific.ASCIIEncoding;
import org.jcodings.specific.USASCIIEncoding;
Expand All @@ -38,8 +44,6 @@
import org.truffleruby.core.array.ArrayBuilderNode;
import org.truffleruby.core.array.ArrayBuilderNode.BuilderState;
import org.truffleruby.core.array.RubyArray;
import org.truffleruby.core.encoding.EncodingNodes;
import org.truffleruby.core.encoding.RubyEncoding;
import org.truffleruby.core.encoding.StandardEncodings;
import org.truffleruby.core.kernel.KernelNodes.SameOrEqualNode;
import org.truffleruby.core.regexp.RegexpNodes.ToSNode;
Expand All @@ -53,6 +57,8 @@
import org.truffleruby.core.string.StringNodes;
import org.truffleruby.core.string.StringNodes.StringAppendPrimitiveNode;
import org.truffleruby.core.string.StringOperations;
import org.truffleruby.interop.TranslateInteropExceptionNode;
import org.truffleruby.interop.TranslateInteropExceptionNodeGen;
import org.truffleruby.language.RubyContextNode;
import org.truffleruby.language.control.DeferredRaiseException;
import org.truffleruby.language.dispatch.DispatchNode;
Expand Down Expand Up @@ -235,25 +241,18 @@ public RubyRegexp createRegexp(Rope pattern) throws DeferredRaiseException {
}
}

/** Singleton core method {@code select_encoding}: returns the {@link RubyEncoding} corresponding to the encoding
 * that {@link CheckEncodingNode} reports for a regexp/string pair. */
@CoreMethod(names = "select_encoding", onSingleton = true, required = 2)
public abstract static class SelectEncodingNode extends CoreMethodArrayArgumentsNode {

    /** Only applies when {@code str} is a Ruby string (see the guard).
     *
     * @param re the regexp taking part in the encoding check
     * @param str the string taking part in the encoding check
     * @return the Ruby-level encoding object for the checked encoding */
    @Specialization(guards = "libString.isRubyString(str)")
    protected RubyEncoding selectEncoding(RubyRegexp re, Object str,
            @Cached EncodingNodes.GetRubyEncodingNode getRubyEncodingNode,
            @Cached CheckEncodingNode checkEncodingNode,
            @CachedLibrary(limit = "2") RubyStringLibrary libString) {
        return getRubyEncodingNode.executeGetRubyEncoding(checkEncodingNode.executeCheckEncoding(re, str));
    }
}

@ImportStatic(StandardEncodings.class)
@CoreMethod(names = "tregex_compile", onSingleton = true, required = 3)
public abstract static class TRegexCompileNode extends CoreMethodArrayArgumentsNode {

@Specialization(guards = "encoding.encoding == US_ASCII")
protected Object usASCII(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
public static TRegexCompileNode create() {
return TruffleRegexpNodesFactory.TRegexCompileNodeFactory.create(null);
}

public abstract Object executeTRegexCompile(RubyRegexp regexp, boolean atStart, Encoding encoding);

@Specialization(guards = "encoding == US_ASCII")
protected Object usASCII(RubyRegexp regexp, boolean atStart, Encoding encoding) {
final Object tregex = regexp.tregexCache.getUSASCIIRegex(atStart);
if (tregex != null) {
return tregex;
Expand All @@ -262,8 +261,8 @@ protected Object usASCII(RubyRegexp regexp, boolean atStart, RubyEncoding encodi
}
}

@Specialization(guards = "encoding.encoding == ISO_8859_1")
protected Object latin1(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
@Specialization(guards = "encoding == ISO_8859_1")
protected Object latin1(RubyRegexp regexp, boolean atStart, Encoding encoding) {
final Object tregex = regexp.tregexCache.getLatin1Regex(atStart);
if (tregex != null) {
return tregex;
Expand All @@ -272,8 +271,8 @@ protected Object latin1(RubyRegexp regexp, boolean atStart, RubyEncoding encodin
}
}

@Specialization(guards = "encoding.encoding == UTF_8")
protected Object utf8(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
@Specialization(guards = "encoding == UTF_8")
protected Object utf8(RubyRegexp regexp, boolean atStart, Encoding encoding) {
final Object tregex = regexp.tregexCache.getUTF8Regex(atStart);
if (tregex != null) {
return tregex;
Expand All @@ -282,8 +281,8 @@ protected Object utf8(RubyRegexp regexp, boolean atStart, RubyEncoding encoding)
}
}

@Specialization(guards = "encoding.encoding == BINARY")
protected Object binary(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
@Specialization(guards = "encoding == BINARY")
protected Object binary(RubyRegexp regexp, boolean atStart, Encoding encoding) {
final Object tregex = regexp.tregexCache.getBinaryRegex(atStart);
if (tregex != null) {
return tregex;
Expand All @@ -294,11 +293,11 @@ protected Object binary(RubyRegexp regexp, boolean atStart, RubyEncoding encodin

@Specialization(
guards = {
"encoding.encoding != US_ASCII",
"encoding.encoding != ISO_8859_1",
"encoding.encoding != UTF_8",
"encoding.encoding != BINARY" })
protected Object other(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
"encoding != US_ASCII",
"encoding != ISO_8859_1",
"encoding != UTF_8",
"encoding != BINARY" })
protected Object other(RubyRegexp regexp, boolean atStart, Encoding encoding) {
return nil;
}

Expand Down Expand Up @@ -362,21 +361,28 @@ protected boolean initialized(RubyRegexp regexp) {
@Primitive(name = "regexp_match_in_region", lowerFixnum = { 2, 3, 5 })
public abstract static class MatchInRegionNode extends PrimitiveArrayArgumentsNode {

public static MatchInRegionNode create() {
return TruffleRegexpNodesFactory.MatchInRegionNodeFactory.create(null);
}

public abstract Object executeMatchInRegion(RubyRegexp regexp, Object string, int fromPos, int toPos,
boolean atStart, int startPos);

/** Matches a regular expression against a string over the specified range of characters.
*
* @param regexp The regexp to match
*
* @param string The string to match against
*
* @param fromPos The poistion to search from
* @param fromPos The position to search from
*
* @param toPos The position to search to (if less than from pos then this means search backwards)
*
* @param atStart Whether to only match at the beginning of the string, if false then the regexp can have any
* amount of prematch.
*
* @param startPos The position within the string which the matcher should consider the start. Setting this to
* the from position allows scanners to match starting partway through a string while still setting
* the from position allows scanners to match starting part-way through a string while still setting
* atStart and thus forcing the match to be at the specific starting position. */
@Specialization(guards = "libString.isRubyString(string)")
protected Object matchInRegion(
Expand All @@ -400,6 +406,140 @@ protected Object matchInRegion(
}
}

/** Primitive {@code regexp_match_in_region_tregex}: attempts the match through a compiled TRegex object and hands the
 * request to {@link MatchInRegionNode} whenever the requested range is not supported here or no TRegex object could
 * be obtained for the regexp. */
@Primitive(name = "regexp_match_in_region_tregex", lowerFixnum = { 2, 3, 5 })
public abstract static class MatchInRegionTRegexNode extends PrimitiveArrayArgumentsNode {

    // These children start out null and are inserted the first time the corresponding path is taken.
    @Child MatchInRegionNode fallbackMatchInRegionNode;
    @Child DispatchNode warnOnFallbackNode;

    @Child DispatchNode stringDupNode;
    @Child TranslateInteropExceptionNode translateInteropExceptionNode;

    /** Runs the match via TRegex when possible.
     *
     * The TRegex path is only taken when {@code fromPos} is non-negative, {@code toPos} equals the byte length of
     * the string's rope (and is not less than {@code fromPos}) and {@code startPos} is 0; every other request, and
     * every regexp for which {@link TRegexCompileNode} yields {@code nil}, is delegated to the fallback node.
     *
     * @return a new match data object referencing a {@code dup} of {@code string} on success, {@code nil} when the
     *         TRegex result reports no match, or whatever the fallback node returns */
    @Specialization(guards = "libString.isRubyString(string)")
    protected Object matchInRegionTRegex(
            RubyRegexp regexp, Object string, int fromPos, int toPos, boolean atStart, int startPos,
            @Cached ConditionProfile matchFoundProfile,
            @Cached ConditionProfile tRegexCouldNotCompileProfile,
            @Cached ConditionProfile tRegexIncompatibleProfile,
            @Cached LoopConditionProfile loopProfile,
            @CachedLibrary(limit = "getInteropCacheLimit()") InteropLibrary regexInterop,
            @CachedLibrary(limit = "getInteropCacheLimit()") InteropLibrary resultInterop,
            @Cached RopeNodes.BytesNode bytesNode,
            @Cached CheckEncodingNode checkEncodingNode,
            @Cached TRegexCompileNode tRegexCompileNode,
            @CachedLibrary(limit = "2") RubyStringLibrary libString,
            @Cached("createIdentityProfile()") IntValueProfile groupCountProfile) {
        final Rope rope = libString.getRope(string);

        // First gate: the shape of the request itself.
        final boolean unsupportedRange = toPos < fromPos || toPos != rope.byteLength() || startPos != 0 ||
                fromPos < 0;
        if (tRegexIncompatibleProfile.profile(unsupportedRange)) {
            return fallbackToJoni(regexp, string, fromPos, toPos, atStart, startPos);
        }

        // Second gate: a TRegex object must be available for this regexp and encoding.
        final Encoding encoding = checkEncodingNode.executeCheckEncoding(regexp, string);
        final Object tRegex = tRegexCompileNode.executeTRegexCompile(regexp, atStart, encoding);
        if (tRegexCouldNotCompileProfile.profile(tRegex == nil)) {
            return fallbackToJoni(regexp, string, fromPos, toPos, atStart, startPos);
        }

        final byte[] ropeBytes = bytesNode.execute(rope);
        final Object guestBytes = getContext().getEnv().asGuestValue(ropeBytes);
        final Object result = invoke(regexInterop, tRegex, "execBytes", guestBytes, fromPos);

        final boolean matched = (boolean) readMember(resultInterop, result, "isMatch");
        if (!matchFoundProfile.profile(matched)) {
            return nil;
        }

        final int groupCount = groupCountProfile.profile((int) readMember(regexInterop, tRegex, "groupCount"));
        final Region region = new Region(groupCount);

        // Copy the start/end offset of every group out of the TRegex result.
        loopProfile.profileCounted(groupCount);
        try {
            for (int index = 0; loopProfile.inject(index < groupCount); index++) {
                region.beg[index] = (int) invoke(resultInterop, result, "getStart", index);
                region.end[index] = (int) invoke(resultInterop, result, "getEnd", index);
            }
        } finally {
            LoopNode.reportLoopCount(this, groupCount);
        }

        final Object stringCopy = dupString(string);
        return createMatchData(regexp, stringCopy, region);
    }

    /** Optionally emits the fallback warning, then runs the match through {@link MatchInRegionNode}. */
    private Object fallbackToJoni(RubyRegexp regexp, Object string, int fromPos, int toPos, boolean atStart,
            int startPos) {
        if (getContext().getOptions().WARN_TRUFFLE_REGEX_FALLBACK) {
            warnAboutFallback(regexp, string, fromPos, toPos, atStart, startPos);
        }

        if (fallbackMatchInRegionNode == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            fallbackMatchInRegionNode = insert(MatchInRegionNode.create());
        }

        return fallbackMatchInRegionNode.executeMatchInRegion(regexp, string, fromPos, toPos, atStart, startPos);
    }

    /** Calls {@code warn_fallback} on the Truffle regexp operations module with the full set of match arguments. */
    private void warnAboutFallback(RubyRegexp regexp, Object string, int fromPos, int toPos, boolean atStart,
            int startPos) {
        if (warnOnFallbackNode == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            warnOnFallbackNode = insert(DispatchNode.create());
        }

        warnOnFallbackNode.call(
                getContext().getCoreLibrary().truffleRegexpOperationsModule,
                "warn_fallback",
                regexp,
                string,
                fromPos,
                toPos,
                atStart,
                startPos);
    }

    /** Allocates the match data object for the given region and reports the allocation for tracing. */
    private Object createMatchData(RubyRegexp regexp, Object string, Region region) {
        final RubyMatchData created = new RubyMatchData(
                coreLibrary().matchDataClass,
                getLanguage().matchDataShape,
                regexp,
                string,
                region,
                null);
        AllocationTracing.trace(created, this);
        return created;
    }

    /** Reads {@code name} from {@code receiver}, rethrowing any interop failure in translated form. */
    private Object readMember(InteropLibrary interop, Object receiver, String name) {
        try {
            return interop.readMember(receiver, name);
        } catch (InteropException e) {
            throw interopExceptionTranslator().execute(e);
        }
    }

    /** Invokes {@code member} on {@code receiver}, rethrowing any interop failure in translated form. */
    private Object invoke(InteropLibrary interop, Object receiver, String member, Object... args) {
        try {
            return interop.invokeMember(receiver, member, args);
        } catch (InteropException e) {
            throw interopExceptionTranslator().executeInInvokeMember(e, receiver, args);
        }
    }

    /** Returns the exception translation node, inserting it on first use. */
    private TranslateInteropExceptionNode interopExceptionTranslator() {
        if (translateInteropExceptionNode == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            translateInteropExceptionNode = insert(TranslateInteropExceptionNodeGen.create());
        }
        return translateInteropExceptionNode;
    }

    /** Sends {@code dup} to {@code string} through a lazily created dispatch node and returns the result. */
    private Object dupString(Object string) {
        if (stringDupNode == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            stringDupNode = insert(DispatchNode.create());
        }

        return stringDupNode.call(string, "dup");
    }
}

public abstract static class MatchNode extends RubyContextNode {

@Child private DispatchNode dupNode = DispatchNode.create();
Expand Down
Loading

0 comments on commit 6eaa842

Please sign in to comment.