Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRegex splitting #2389

Merged
merged 3 commits into from
Jun 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions src/main/java/org/truffleruby/core/regexp/MatchDataNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,56 @@ protected Object create(Object regexp, Object string, int start, int end) {
@Primitive(name = "matchdata_create")
public abstract static class MatchDataCreateNode extends PrimitiveArrayArgumentsNode {

public static MatchDataCreateNode create() {
return MatchDataNodesFactory.MatchDataCreateNodeFactory.create(null);
}

public abstract Object executeMatchDataCreate(RubyRegexp regexp, Object string, Object starts, Object ends);
eregon marked this conversation as resolved.
Show resolved Hide resolved

@Specialization
protected Object create(Object regexp, Object string, RubyArray starts, RubyArray ends,
protected Object matchDataCreate(RubyRegexp regexp, Object string, RubyArray starts, RubyArray ends,
eregon marked this conversation as resolved.
Show resolved Hide resolved
@Cached LoopConditionProfile loopProfile,
@Cached ArrayIndexNodes.ReadNormalizedNode readNode,
@Cached IntegerCastNode integerCastNode) {
final Region region = new Region(starts.size);
for (int i = 0; i < region.numRegs; i++) {
region.beg[i] = integerCastNode.executeCastInt(readNode.executeRead(starts, i));
region.end[i] = integerCastNode.executeCastInt(readNode.executeRead(ends, i));

try {
loopProfile.profileCounted(region.numRegs);

for (int i = 0; loopProfile.inject(i < region.numRegs); i++) {
region.beg[i] = integerCastNode.executeCastInt(readNode.executeRead(starts, i));
region.end[i] = integerCastNode.executeCastInt(readNode.executeRead(ends, i));
}
} finally {
LoopNode.reportLoopCount(this, region.numRegs);
}

RubyMatchData matchData = new RubyMatchData(
return createMatchData(regexp, string, region);
}

@Specialization
protected Object create(RubyRegexp regexp, Object string, int[] starts, int[] ends,
@Cached LoopConditionProfile loopProfile,
@Cached ArrayIndexNodes.ReadNormalizedNode readNode,
@Cached IntegerCastNode integerCastNode) {
final Region region = new Region(starts.length);

try {
loopProfile.profileCounted(region.numRegs);

for (int i = 0; loopProfile.inject(i < region.numRegs); i++) {
region.beg[i] = starts[i];
region.end[i] = ends[i];
}
} finally {
LoopNode.reportLoopCount(this, region.numRegs);
}

return createMatchData(regexp, string, region);
}

private Object createMatchData(RubyRegexp regexp, Object string, Region region) {
final RubyMatchData matchData = new RubyMatchData(
coreLibrary().matchDataClass,
getLanguage().matchDataShape,
regexp,
Expand Down
162 changes: 158 additions & 4 deletions src/main/java/org/truffleruby/core/regexp/TruffleRegexpNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.LoopNode;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import org.jcodings.Encoding;
import org.jcodings.specific.ASCIIEncoding;
import org.jcodings.specific.USASCIIEncoding;
Expand Down Expand Up @@ -238,12 +241,18 @@ public RubyRegexp createRegexp(Rope pattern) throws DeferredRaiseException {
@CoreMethod(names = "select_encoding", onSingleton = true, required = 2)
public abstract static class SelectEncodingNode extends CoreMethodArrayArgumentsNode {

public static SelectEncodingNode create() {
return TruffleRegexpNodesFactory.SelectEncodingNodeFactory.create(null);
}

public abstract RubyEncoding executeSelectEncoding(RubyRegexp regexp, Object str);

@Specialization(guards = "libString.isRubyString(str)")
protected RubyEncoding selectEncoding(RubyRegexp re, Object str,
protected RubyEncoding selectEncoding(RubyRegexp regexp, Object str,
@Cached EncodingNodes.GetRubyEncodingNode getRubyEncodingNode,
@Cached CheckEncodingNode checkEncodingNode,
@CachedLibrary(limit = "2") RubyStringLibrary libString) {
final Encoding encoding = checkEncodingNode.executeCheckEncoding(re, str);
final Encoding encoding = checkEncodingNode.executeCheckEncoding(regexp, str);
return getRubyEncodingNode.executeGetRubyEncoding(encoding);
}
}
Expand All @@ -252,6 +261,12 @@ protected RubyEncoding selectEncoding(RubyRegexp re, Object str,
@CoreMethod(names = "tregex_compile", onSingleton = true, required = 3)
public abstract static class TRegexCompileNode extends CoreMethodArrayArgumentsNode {

public static TRegexCompileNode create() {
return TruffleRegexpNodesFactory.TRegexCompileNodeFactory.create(null);
}

public abstract Object executeTRegexCompile(RubyRegexp regexp, boolean atStart, RubyEncoding encoding);

@Specialization(guards = "encoding.encoding == US_ASCII")
protected Object usASCII(RubyRegexp regexp, boolean atStart, RubyEncoding encoding) {
final Object tregex = regexp.tregexCache.getUSASCIIRegex(atStart);
Expand Down Expand Up @@ -362,21 +377,28 @@ protected boolean initialized(RubyRegexp regexp) {
@Primitive(name = "regexp_match_in_region", lowerFixnum = { 2, 3, 5 })
public abstract static class MatchInRegionNode extends PrimitiveArrayArgumentsNode {

public static MatchInRegionNode create() {
return TruffleRegexpNodesFactory.MatchInRegionNodeFactory.create(null);
}

public abstract Object executeMatchInRegion(RubyRegexp regexp, Object string, int fromPos, int toPos,
boolean atStart, int startPos);

/** Matches a regular expression against a string over the specified range of characters.
*
* @param regexp The regexp to match
*
* @param string The string to match against
*
* @param fromPos The poistion to search from
* @param fromPos The position to search from
*
* @param toPos The position to search to (if less than from pos then this means search backwards)
*
* @param atStart Whether to only match at the beginning of the string, if false then the regexp can have any
* amount of prematch.
*
* @param startPos The position within the string which the matcher should consider the start. Setting this to
* the from position allows scanners to match starting partway through a string while still setting
* the from position allows scanners to match starting part-way through a string while still setting
* atStart and thus forcing the match to be at the specific starting position. */
@Specialization(guards = "libString.isRubyString(string)")
protected Object matchInRegion(
Expand All @@ -400,6 +422,138 @@ protected Object matchInRegion(
}
}

@Primitive(name = "regexp_match_in_region_tregex", lowerFixnum = { 2, 3, 5 })
public abstract static class MatchInRegionTRegexNode extends PrimitiveArrayArgumentsNode {

@Child MatchInRegionNode fallbackMatchInRegionNode;
@Child DispatchNode warnOnFallbackNode;

@Child DispatchNode stringDupNode;
@Child DispatchNode tRegexGroupCountNode;
@Child DispatchNode tRegexGetStartNode;
@Child DispatchNode tRegexGetEndNode;

@Specialization(guards = "libString.isRubyString(string)")
protected Object matchInRegionTRegex(
RubyRegexp regexp, Object string, int fromPos, int toPos, boolean atStart, int startPos,
@Cached ConditionProfile matchFoundProfile,
@Cached ConditionProfile tRegexCouldNotCompileProfile,
@Cached ConditionProfile tRegexIncompatibleProfile,
@Cached LoopConditionProfile loopProfile,
@Cached DispatchNode tRegexExecBytesNode,
@Cached DispatchNode tRegexIsMatchNode,
@Cached RopeNodes.BytesNode bytesNode,
@Cached MatchDataNodes.MatchDataCreateNode matchDataCreateNode,
@Cached SelectEncodingNode selectEncodingNode,
@Cached TRegexCompileNode tRegexCompileNode,
@CachedLibrary(limit = "2") RubyStringLibrary libString) {
final Rope rope = libString.getRope(string);

if (tRegexIncompatibleProfile
.profile(toPos < fromPos || toPos != rope.byteLength() || startPos != 0 || fromPos < 0)) {
return fallbackToJoni(regexp, string, fromPos, toPos, atStart, startPos);
} else {
final RubyEncoding encoding = selectEncodingNode.executeSelectEncoding(regexp, string);
final Object compiledRegex = tRegexCompileNode.executeTRegexCompile(regexp, atStart, encoding);

if (tRegexCouldNotCompileProfile.profile(compiledRegex == nil)) {
return fallbackToJoni(regexp, string, fromPos, toPos, atStart, startPos);
eregon marked this conversation as resolved.
Show resolved Hide resolved
}

final byte[] bytes = bytesNode.execute(rope);
final Object regexResult = tRegexExecBytesNode
.call(compiledRegex, "execBytes", getContext().getEnv().asGuestValue(bytes), fromPos);
eregon marked this conversation as resolved.
Show resolved Hide resolved

final boolean isMatch = (boolean) tRegexIsMatchNode.call(regexResult, "isMatch");

if (matchFoundProfile.profile(isMatch)) {
final int groupCount = tRegexGroupCount(compiledRegex);
final int[] starts = new int[groupCount];
final int[] ends = new int[groupCount];

try {
loopProfile.profileCounted(groupCount);

for (int pos = 0; loopProfile.inject(pos < groupCount); pos++) {
starts[pos] = tRegexGetStart(regexResult, pos);
ends[pos] = tRegexGetEnd(regexResult, pos);
}
} finally {
LoopNode.reportLoopCount(this, groupCount);
}

return matchDataCreateNode
.executeMatchDataCreate(regexp, dupString(string), starts, ends);
eregon marked this conversation as resolved.
Show resolved Hide resolved
} else {
return nil;
}
}
}

private Object fallbackToJoni(RubyRegexp regexp, Object string, int fromPos, int toPos, boolean atStart,
int startPos) {
if (getContext().getOptions().WARN_TRUFFLE_REGEX_FALLBACK) {
if (warnOnFallbackNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
warnOnFallbackNode = insert(DispatchNode.create());
}

warnOnFallbackNode.call(
getContext().getCoreLibrary().truffleRegexpOperationsModule,
"warn_fallback",
regexp,
string,
fromPos,
toPos,
atStart,
startPos);
}

if (fallbackMatchInRegionNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
fallbackMatchInRegionNode = insert(MatchInRegionNode.create());
}

return fallbackMatchInRegionNode.executeMatchInRegion(regexp, string, fromPos, toPos, atStart, startPos);
}

private int tRegexGroupCount(Object compiledRegex) {
if (tRegexGroupCountNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
tRegexGroupCountNode = insert(DispatchNode.create());
}

return (int) tRegexGroupCountNode.call(compiledRegex, "groupCount");
}

private int tRegexGetStart(Object regexResult, int pos) {
if (tRegexGetStartNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
tRegexGetStartNode = insert(DispatchNode.create());
}

return (int) tRegexGetStartNode.call(regexResult, "getStart", pos);
}

private int tRegexGetEnd(Object regexResult, int pos) {
if (tRegexGetEndNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
tRegexGetEndNode = insert(DispatchNode.create());
}

return (int) tRegexGetEndNode.call(regexResult, "getEnd", pos);
}

private Object dupString(Object string) {
if (stringDupNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
stringDupNode = insert(DispatchNode.create());
}

return stringDupNode.call(string, "dup");
}
}

public abstract static class MatchNode extends RubyContextNode {

@Child private DispatchNode dupNode = DispatchNode.create();
Expand Down
13 changes: 0 additions & 13 deletions src/main/java/org/truffleruby/core/string/TruffleStringNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
import org.truffleruby.core.rope.Rope;
import org.truffleruby.core.rope.RopeNodes;
import org.truffleruby.language.control.RaiseException;
import org.truffleruby.language.library.RubyStringLibrary;

import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.library.CachedLibrary;

@CoreModule("Truffle::StringOperations")
public class TruffleStringNodes {
Expand Down Expand Up @@ -71,15 +69,4 @@ private String formatTooLongError(int count, final Rope rope) {
}

}

@CoreMethod(names = "raw_bytes", onSingleton = true, required = 1)
public abstract static class RawBytesNode extends CoreMethodArrayArgumentsNode {
@Specialization(guards = "libString.isRubyString(string)")
protected Object rawBytes(Object string,
@CachedLibrary(limit = "2") RubyStringLibrary libString,
@Cached RopeNodes.BytesNode bytesNode) {
byte[] bytes = bytesNode.execute(libString.getRope(string));
return getContext().getEnv().asGuestValue(bytes);
}
}
}
27 changes: 2 additions & 25 deletions src/main/ruby/truffleruby/core/truffle/regexp_operations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def self.match_in_region(re, str, from, to, at_start, start)
if COMPARE_ENGINES
match_in_region_compare_engines(re, str, from, to, at_start, start)
elsif USE_TRUFFLE_REGEX
match_in_region_tregex(re, str, from, to, at_start, start)
Primitive.regexp_match_in_region_tregex(re, str, from, to, at_start, start)
else
Primitive.regexp_match_in_region(re, str, from, to, at_start, start)
end
end

def self.match_in_region_compare_engines(re, str, from, to, at_start, start)
begin
md1 = match_in_region_tregex(re, str, from, to, at_start, start)
md1 = Primitive.regexp_match_in_region_tregex(re, str, from, to, at_start, start)
rescue => e
md1 = e
end
Expand All @@ -105,29 +105,6 @@ def self.match_in_region_compare_engines(re, str, from, to, at_start, start)
end
end

def self.match_in_region_tregex(re, str, from, to, at_start, start)
if to < from || to != str.bytesize || start != 0 || from < 0 ||
Primitive.nil?((compiled_regex = tregex_compile(re, at_start, select_encoding(re, str))))
warn_fallback(re, str, from, to, at_start, start) if WARN_TRUFFLE_REGEX_FALLBACK
return Primitive.regexp_match_in_region(re, str, from, to, at_start, start)
end

str_bytes = StringOperations.raw_bytes(str)
regex_result = compiled_regex.execBytes(str_bytes, from)

if regex_result.isMatch
starts = []
ends = []
compiled_regex.groupCount.times do |pos|
starts << regex_result.getStart(pos)
ends << regex_result.getEnd(pos)
end
Primitive.matchdata_create(re, str.dup, starts, ends)
else
nil
end
end

def self.warn_fallback(re, str, from, to, at_start, start)
warn match_args_to_string(re, str, from, to, at_start, start, 'cannot be run as a Truffle regexp and fell back to Joni'), uplevel: 1
end
Expand Down