Skip to content

Commit

Permalink
Generate constant regex for match with a constant pattern.
Browse files Browse the repository at this point in the history
For constraints like match(".*", X), where the pattern is a
string constant, we can avoid using the regex cache. Avoiding the
cache is going to perform better in all cases.

I introduced the new node RegexConstant which only appears immediately
below a MATCH or NOT_MATCH node.

The behavior with regards to invalid regexes has not changed and bad
regexes will not lead to program errors.
  • Loading branch information
strRM committed Oct 24, 2022
1 parent 9b9db6c commit 08a5fad
Show file tree
Hide file tree
Showing 15 changed files with 222 additions and 31 deletions.
56 changes: 38 additions & 18 deletions src/interpreter/Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,33 +1023,53 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) {
COMPARE(GE, >=)

case BinaryConstraintOp::MATCH: {
RamDomain left = execute(shadow.getLhs(), ctxt);
bool result = false;
RamDomain right = execute(shadow.getRhs(), ctxt);
const std::string& pattern = getSymbolTable().decode(left);
const std::string& text = getSymbolTable().decode(right);
bool result = false;
try {
const std::regex& regex = regexCache.getOrCreate(pattern);
result = std::regex_match(text, regex);
} catch (...) {
std::cerr << "warning: wrong pattern provided for match(\"" << pattern << "\",\""
<< text << "\").\n";

const Node* patternNode = shadow.getLhs();
if (const RegexConstant* regexNode = dynamic_cast<const RegexConstant*>(patternNode);
regexNode) {
const auto& regex = regexNode->getRegex();
if (regex) {
result = std::regex_match(text, *regex);
}
} else {
RamDomain left = execute(patternNode, ctxt);
const std::string& pattern = getSymbolTable().decode(left);
try {
const std::regex& regex = regexCache.getOrCreate(pattern);
result = std::regex_match(text, regex);
} catch (...) {
std::cerr << "warning: wrong pattern provided for match(\"" << pattern << "\",\""
<< text << "\").\n";
}
}

return result;
}
case BinaryConstraintOp::NOT_MATCH: {
RamDomain left = execute(shadow.getLhs(), ctxt);
bool result = false;
RamDomain right = execute(shadow.getRhs(), ctxt);
const std::string& pattern = getSymbolTable().decode(left);
const std::string& text = getSymbolTable().decode(right);
bool result = false;
try {
const std::regex& regex = regexCache.getOrCreate(pattern);
result = !std::regex_match(text, regex);
} catch (...) {
std::cerr << "warning: wrong pattern provided for !match(\"" << pattern << "\",\""
<< text << "\").\n";

const Node* patternNode = shadow.getLhs();
if (const RegexConstant* regexNode = dynamic_cast<const RegexConstant*>(patternNode);
regexNode) {
const auto& regex = regexNode->getRegex();
if (regex) {
result = !std::regex_match(text, *regex);
}
} else {
RamDomain left = execute(patternNode, ctxt);
const std::string& pattern = getSymbolTable().decode(left);
try {
const std::regex& regex = regexCache.getOrCreate(pattern);
result = !std::regex_match(text, regex);
} catch (...) {
std::cerr << "warning: wrong pattern provided for !match(\"" << pattern << "\",\""
<< text << "\").\n";
}
}
return result;
}
Expand Down
24 changes: 23 additions & 1 deletion src/interpreter/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,29 @@ NodePtr NodeGenerator::visit_(
}

NodePtr NodeGenerator::visit_(type_identity<ram::Constraint>, const ram::Constraint& relOp) {
return mk<Constraint>(I_Constraint, &relOp, dispatch(relOp.getLHS()), dispatch(relOp.getRHS()));
auto left = dispatch(relOp.getLHS());
auto right = dispatch(relOp.getRHS());
switch (relOp.getOperator()) {
case BinaryConstraintOp::MATCH:
case BinaryConstraintOp::NOT_MATCH:
if (const StringConstant* str = dynamic_cast<const StringConstant*>(left.get()); str) {
const std::string& pattern = engine.getSymbolTable().unsafeDecode(str->getConstant());
try {
std::regex regex(pattern);
// treat the string constant as a regex
left = mk<RegexConstant>(*str, std::move(regex));
} catch (const std::exception&) {
std::cerr << "warning: wrong pattern provided \"" << pattern << "\"\n";

// we could not compile the pattern
left = mk<RegexConstant>(*str, std::nullopt);
}
}
break;
default: break;
}

return mk<Constraint>(I_Constraint, &relOp, std::move(left), std::move(right));
}

NodePtr NodeGenerator::visit_(type_identity<ram::NestedOperation>, const ram::NestedOperation& nested) {
Expand Down
17 changes: 17 additions & 0 deletions src/interpreter/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <cassert>
#include <cstddef>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -439,6 +440,22 @@ class StringConstant : public Node {
std::size_t constant;
};

/**
* @class RegexConstant
*/
class RegexConstant : public StringConstant {
public:
RegexConstant(const StringConstant& c, std::optional<std::regex> r)
: StringConstant(c.getType(), c.getShadow(), c.getConstant()), regex(std::move(r)) {}

inline const std::optional<std::regex>& getRegex() const {
return regex;
}

private:
const std::optional<std::regex> regex;
};

/**
* @class TupleElement
*/
Expand Down
82 changes: 70 additions & 12 deletions src/synthesiser/Synthesiser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,21 @@ ram::RelationSet Synthesiser::getReferencedRelations(const Operation& op) {
return res;
}

std::optional<std::size_t> Synthesiser::compileRegex(const std::string& pattern) {
auto i = regexes.find(pattern);
if (i != regexes.end()) {
return i->second;
}
try {
const std::regex regex(pattern);
std::size_t index = regexes.size();
return regexes.emplace(pattern, index).first->second;
} catch (const std::exception&) {
std::cerr << "warning: wrong pattern provided \"" << pattern << "\"\n";
return std::nullopt;
}
}

void Synthesiser::emitCode(std::ostream& out, const Statement& stmt) {
class CodeEmitter : public ram::Visitor<void, Node const, std::ostream&> {
using ram::Visitor<void, Node const, std::ostream&>::visit_;
Expand Down Expand Up @@ -1828,21 +1843,43 @@ void Synthesiser::emitCode(std::ostream& out, const Statement& stmt) {

// strings
case BinaryConstraintOp::MATCH: {
synthesiser.SubroutineUsingStdRegex = true;
out << "regex_wrapper(symTable.decode(";
dispatch(rel.getLHS(), out);
out << "),symTable.decode(";
dispatch(rel.getRHS(), out);
out << "))";
if (const StringConstant* str = dynamic_cast<const StringConstant*>(&rel.getLHS()); str) {
const auto& regex = synthesiser.compileRegex(str->getConstant());
if (regex) {
out << "std::regex_match(symTable.decode(";
dispatch(rel.getRHS(), out);
out << "), regexes.at(" << *regex << "))";
} else {
out << "false";
}
} else {
synthesiser.SubroutineUsingStdRegex = true;
out << "regex_wrapper(symTable.decode(";
dispatch(rel.getLHS(), out);
out << "),symTable.decode(";
dispatch(rel.getRHS(), out);
out << "))";
}
break;
}
case BinaryConstraintOp::NOT_MATCH: {
synthesiser.SubroutineUsingStdRegex = true;
out << "!regex_wrapper(symTable.decode(";
dispatch(rel.getLHS(), out);
out << "),symTable.decode(";
dispatch(rel.getRHS(), out);
out << "))";
if (const StringConstant* str = dynamic_cast<const StringConstant*>(&rel.getLHS()); str) {
const auto& regex = synthesiser.compileRegex(str->getConstant());
if (regex) {
out << "!std::regex_match(symTable.decode(";
dispatch(rel.getRHS(), out);
out << "), regexes.at(" << *regex << "))";
} else {
out << "false";
}
} else {
synthesiser.SubroutineUsingStdRegex = true;
out << "!regex_wrapper(symTable.decode(";
dispatch(rel.getLHS(), out);
out << "),symTable.decode(";
dispatch(rel.getRHS(), out);
out << "))";
}
break;
}
case BinaryConstraintOp::CONTAINS: {
Expand Down Expand Up @@ -2680,6 +2717,27 @@ void Synthesiser::generateCode(GenDb& db, const std::string& id, bool& withShare
<< " return result;\n";
}

if (!regexes.empty()) {
gen.addField("std::vector<std::regex>", "regexes", Visibility::Private);
std::stringstream rst;
// we need to collect the patterns first and place each
// one into the correct slot
std::vector<std::string> patterns;
patterns.resize(regexes.size());
for (const auto& pi : regexes) {
patterns.at(pi.second) = pi.first;
}
rst << "{\n";
for (const auto& p : patterns) {
const std::string escaped = escape(p);
rst << "\tstd::regex(\"" << escaped << "\"),\n";
}
rst << "}";

constructor.setNextInitializer("regexes", rst.str());
regexes.clear();
}

// substring wrapper
if (SubroutineUsingSubstr) {
GenFunction& wrapper = gen.addFunction("substr_wrapper", Visibility::Private);
Expand Down
10 changes: 10 additions & 0 deletions src/synthesiser/Synthesiser.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <map>
#include <memory>
#include <ostream>
#include <regex>
#include <set>
#include <string>

Expand Down Expand Up @@ -72,6 +73,12 @@ class Synthesiser {
bool SubroutineUsingStdRegex = false;
bool SubroutineUsingSubstr = false;

/** A mapping of valid regex patterns to to a unique index.
* The index to which the pattern is mapped is in
* the range from 0 regexes.size()-1.
*/
std::map<std::string, std::size_t> regexes;

/** Pointer to the subroutine class currently being built */
GenClass* currentClass = nullptr;

Expand Down Expand Up @@ -107,6 +114,9 @@ class Synthesiser {
/** Get referenced relations */
ram::RelationSet getReferencedRelations(const ram::Operation& op);

/** Compile a regular expression and return a unique name for it */
std::optional<std::size_t> compileRegex(const std::string& pattern);

/** Generate code */
void emitCode(std::ostream& out, const ram::Statement& stmt);

Expand Down
2 changes: 2 additions & 0 deletions tests/evaluation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ positive_test(aggregate_witnesses)
positive_test(aliases)
positive_test(arithm)
positive_test(average)
positive_test(bad_regex)
positive_test(binop)
positive_test(cat)
positive_test(choice_advisor)
Expand Down Expand Up @@ -120,6 +121,7 @@ positive_test(magic_turing1)
positive_test(match2)
positive_test(match3)
positive_test(match4)
positive_test(match5)
positive_test(match COMPILED_SPLITTED)
# TODO (see issue #298) positive_test(math)
positive_test(max)
Expand Down
1 change: 1 addition & 0 deletions tests/evaluation/bad_regex/B.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foobar
8 changes: 8 additions & 0 deletions tests/evaluation/bad_regex/bad_regex.dl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.decl A(x:symbol)
.input A()

.decl B(x:symbol)
.output B()

B(x) :- A(x),match("fo{2}b.*",x).
B(x) :- A(x),!match("fo{x}b.*",x).
1 change: 1 addition & 0 deletions tests/evaluation/bad_regex/bad_regex.err
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
warning: wrong pattern provided "fo{x}b.*"
Empty file.
2 changes: 2 additions & 0 deletions tests/evaluation/bad_regex/facts/A.facts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
foobar
fooobar
44 changes: 44 additions & 0 deletions tests/evaluation/match5/match5.dl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Souffle - A Datalog Compiler
// Copyright (c) 2020, The Souffle Developers. All rights reserved
// Licensed under the Universal Permissive License v 1.0 as shown at:
// - https://opensource.org/licenses/UPL
// - <souffle root>/licenses/SOUFFLE-UPL.txt

// Check string matching from relations with mixed valid and invalid regex

.type String <: symbol


// Split into several forced strata to keep warnings deterministic
// Note: we add a "dummy" check to prevent optimisation
.decl dummy(x:String)
dummy("dummy").

.decl inputDataStep1, inputDataStep2, inputDataStep3(x:String)
inputDataStep1("a").
inputDataStep2("aaaa").
inputDataStep3("bdab").

.decl outputDataStep1(x:String)
outputDataStep1(x) :- inputDataStep1(x), match("a.*", x).
outputDataStep1(x) :- inputDataStep1(x), match("b.*", x).
outputDataStep1(x) :- inputDataStep1(x), match("b.*[", x).
outputDataStep1(x) :- inputDataStep1(x), match("a.*[a]", x).

.decl outputDataStep2(x:String)
outputDataStep2(x) :- dummy(x), !outputDataStep1(x), x != "dummy".
outputDataStep2(x) :- inputDataStep2(x), match("a.*", x).
outputDataStep2(x) :- inputDataStep2(x), match("b.*", x).
outputDataStep2(x) :- inputDataStep2(x), match("b.*[", x).
outputDataStep2(x) :- inputDataStep2(x), match("a.*[a]", x).

.decl outputDataStep3(x:String)
outputDataStep3(x) :- dummy(x), !outputDataStep2(x), x != "dummy".
outputDataStep3(x) :- inputDataStep3(x), match("a.*", x).
outputDataStep3(x) :- inputDataStep3(x), match("b.*", x).
outputDataStep3(x) :- inputDataStep3(x), match("b.*[", x).
outputDataStep3(x) :- inputDataStep3(x), match("a.*[a]", x).

.decl outputData(x:String)
.output outputData()
outputData(x) :- outputDataStep1(x) ; outputDataStep2(x) ; outputDataStep3(x).
3 changes: 3 additions & 0 deletions tests/evaluation/match5/match5.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
warning: wrong pattern provided "b.*["
warning: wrong pattern provided "b.*["
warning: wrong pattern provided "b.*["
Empty file.
3 changes: 3 additions & 0 deletions tests/evaluation/match5/outputData.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
a
aaaa
bdab

0 comments on commit 08a5fad

Please sign in to comment.