Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Drastically improve multi-threaded performance #3550

Merged
merged 1 commit into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public:
virtual std::unique_ptr\<antlr4::Token> nextToken() override {
if (dynamic_cast\<PositionAdjustingLexerATNSimulator *>(_interpreter) == nullptr) {
delete _interpreter;
_interpreter = new PositionAdjustingLexerATNSimulator(this, _atn, _decisionToDFA, _sharedContextCache);
_interpreter = new PositionAdjustingLexerATNSimulator(this, *_atn, *_decisionToDFA, _sharedContextCache);
}

return antlr4::Lexer::nextToken();
Expand Down
46 changes: 33 additions & 13 deletions runtime/Cpp/runtime/src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,25 @@

using namespace antlr4;
using namespace antlr4::atn;

using namespace antlrcpp;

std::map<std::vector<uint16_t>, atn::ATN> Parser::bypassAltsAtnCache;
namespace {

/// Holds the cache of ATNs that were deserialized with rule bypass transitions enabled,
/// together with the lock that protects it. Lookups take `mutex` shared and insertions
/// take it exclusively (see Parser::getATNWithBypassAlts).
struct BypassAltsAtnCache final {
/// Guards every access to `map`.
std::shared_mutex mutex;
/// Maps a serialized ATN (the whole uint16_t vector is the key) to the ATN deserialized
/// from it with bypass alternatives. The map owns the cached ATNs.
///
/// <seealso cref="ATNDeserializationOptions#isGenerateRuleBypassTransitions()"/>
std::map<std::vector<uint16_t>, std::unique_ptr<const atn::ATN>> map;
};

/// Returns the single BypassAltsAtnCache shared by all callers in this translation unit.
/// The cache is created on the first call and is never deleted by this function, so the
/// returned pointer is never null and does not dangle.
BypassAltsAtnCache* getBypassAltsAtnCache() {
  // Function-local static: the allocation runs once, on first use.
  static BypassAltsAtnCache* const sharedCache = new BypassAltsAtnCache();
  return sharedCache;
}

}

Parser::TraceListener::TraceListener(Parser *outerInstance_) : outerInstance(outerInstance_) {
}
Expand Down Expand Up @@ -214,25 +229,30 @@ TokenFactory<CommonToken>* Parser::getTokenFactory() {


const atn::ATN& Parser::getATNWithBypassAlts() {
std::vector<uint16_t> serializedAtn = getSerializedATN();
const std::vector<uint16_t> &serializedAtn = getSerializedATN();
if (serializedAtn.empty()) {
throw UnsupportedOperationException("The current parser does not support an ATN with bypass alternatives.");
}

std::lock_guard<std::mutex> lck(_mutex);

// XXX: using the entire serialized ATN as key into the map is a big resource waste.
// How large can that thing become?
if (bypassAltsAtnCache.find(serializedAtn) == bypassAltsAtnCache.end())
auto *cache = getBypassAltsAtnCache();
{
jcking marked this conversation as resolved.
Show resolved Hide resolved
atn::ATNDeserializationOptions deserializationOptions;
deserializationOptions.setGenerateRuleBypassTransitions(true);

atn::ATNDeserializer deserializer(deserializationOptions);
bypassAltsAtnCache[serializedAtn] = deserializer.deserialize(serializedAtn);
std::shared_lock<std::shared_mutex> lock(cache->mutex);
auto existing = cache->map.find(serializedAtn);
if (existing != cache->map.end()) {
return *existing->second;
}
}

return bypassAltsAtnCache[serializedAtn];
atn::ATNDeserializationOptions deserializationOptions;
deserializationOptions.setGenerateRuleBypassTransitions(true);
atn::ATNDeserializer deserializer(deserializationOptions);
auto atn = deserializer.deserialize(serializedAtn);

{
std::unique_lock<std::shared_mutex> lock(cache->mutex);
return *cache->map.insert(std::make_pair(serializedAtn, std::move(atn))).first->second;
}
}

tree::pattern::ParseTreePattern Parser::compileParseTreePattern(const std::string &pattern, int patternRuleIndex) {
Expand Down
6 changes: 0 additions & 6 deletions runtime/Cpp/runtime/src/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,6 @@ namespace antlr4 {
tree::ParseTreeTracker _tracker;

private:
/// This field maps from the serialized ATN string to the deserialized <seealso cref="ATN"/> with
/// bypass alternatives.
///
/// <seealso cref= ATNDeserializationOptions#isGenerateRuleBypassTransitions() </seealso>
static std::map<std::vector<uint16_t>, atn::ATN> bypassAltsAtnCache;

/// When setTrace(true) is called, a reference to the
/// TraceListener is stored here so it can be easily removed in a
/// later call to setTrace(false). The listener itself is
Expand Down
2 changes: 1 addition & 1 deletion runtime/Cpp/runtime/src/Recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace antlr4 {
/// For interpreters, we don't know their serialized ATN despite having
/// created the interpreter from it.
/// </summary>
virtual const std::vector<uint16_t> getSerializedATN() const {
virtual const std::vector<uint16_t>& getSerializedATN() const {
throw "there is no serialized ATN";
}

Expand Down
57 changes: 3 additions & 54 deletions runtime/Cpp/runtime/src/atn/ATN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,67 +20,16 @@ using namespace antlr4;
using namespace antlr4::atn;
using namespace antlrcpp;

ATN::ATN() : ATN(ATNType::LEXER, 0) {
}
ATN::ATN() : ATN(ATNType::LEXER, 0) {}

ATN::ATN(ATN &&other) {
// All source vectors are implicitly cleared by the moves.
states = std::move(other.states);
decisionToState = std::move(other.decisionToState);
ruleToStartState = std::move(other.ruleToStartState);
ruleToStopState = std::move(other.ruleToStopState);
grammarType = std::move(other.grammarType);
maxTokenType = std::move(other.maxTokenType);
ruleToTokenType = std::move(other.ruleToTokenType);
lexerActions = std::move(other.lexerActions);
modeToStartState = std::move(other.modeToStartState);
}

ATN::ATN(ATNType grammarType_, size_t maxTokenType_) : grammarType(grammarType_), maxTokenType(maxTokenType_) {
}
ATN::ATN(ATNType grammarType_, size_t maxTokenType_) : grammarType(grammarType_), maxTokenType(maxTokenType_) {}

/// Destroys the ATN and frees every ATNState stored in `states`.
ATN::~ATN() {
  // Each entry in `states` is a raw pointer that this destructor deletes.
  for (auto it = states.begin(); it != states.end(); ++it) {
    delete *it;
  }
}

/**
* Required to be defined (even though not used) as we have an explicit move assignment operator.
*/
ATN& ATN::operator = (ATN &other) noexcept {
states = other.states;
decisionToState = other.decisionToState;
ruleToStartState = other.ruleToStartState;
ruleToStopState = other.ruleToStopState;
grammarType = other.grammarType;
maxTokenType = other.maxTokenType;
ruleToTokenType = other.ruleToTokenType;
lexerActions = other.lexerActions;
modeToStartState = other.modeToStartState;

return *this;
}

/**
* Explicit move assignment operator to make this the preferred assignment. With implicit copy/move assignment
* operators it seems the copy operator is preferred causing trouble when releasing the allocated ATNState instances.
*/
ATN& ATN::operator = (ATN &&other) noexcept {
// All source vectors are implicitly cleared by the moves.
states = std::move(other.states);
decisionToState = std::move(other.decisionToState);
ruleToStartState = std::move(other.ruleToStartState);
ruleToStopState = std::move(other.ruleToStopState);
grammarType = std::move(other.grammarType);
maxTokenType = std::move(other.maxTokenType);
ruleToTokenType = std::move(other.ruleToTokenType);
lexerActions = std::move(other.lexerActions);
modeToStartState = std::move(other.modeToStartState);

return *this;
}

misc::IntervalSet ATN::nextTokens(ATNState *s, RuleContext *ctx) const {
LL1Analyzer analyzer(*this);
return analyzer.LOOK(s, ctx);
Expand All @@ -89,7 +38,7 @@ misc::IntervalSet ATN::nextTokens(ATNState *s, RuleContext *ctx) const {

misc::IntervalSet const& ATN::nextTokens(ATNState *s) const {
if (!s->_nextTokenUpdated) {
std::unique_lock<std::mutex> lock { _mutex };
std::unique_lock<std::mutex> lock(_mutex);
if (!s->_nextTokenUpdated) {
s->_nextTokenWithinRule = nextTokens(s, nullptr);
s->_nextTokenUpdated = true;
Expand Down
40 changes: 27 additions & 13 deletions runtime/Cpp/runtime/src/atn/ATN.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,27 @@
namespace antlr4 {
namespace atn {

class LexerATNSimulator;
class ParserATNSimulator;

class ANTLR4CPP_PUBLIC ATN {
public:
static constexpr size_t INVALID_ALT_NUMBER = 0;

/// Used for runtime deserialization of ATNs from strings.
ATN();
ATN(ATN &&other);

ATN(ATNType grammarType, size_t maxTokenType);
virtual ~ATN();

ATN(const ATN&) = delete;

ATN(ATN&&) = delete;

~ATN();

ATN& operator=(const ATN&) = delete;

ATN& operator=(ATN&&) = delete;

std::vector<ATNState *> states;

Expand Down Expand Up @@ -60,33 +72,30 @@ namespace atn {

std::vector<TokensStartState *> modeToStartState;

ATN& operator = (ATN &other) noexcept;
ATN& operator = (ATN &&other) noexcept;

/// <summary>
/// Compute the set of valid tokens that can occur starting in state {@code s}.
/// If {@code ctx} is null, the set of tokens will not include what can follow
/// the rule surrounding {@code s}. In other words, the set will be
/// restricted to tokens reachable staying within {@code s}'s rule.
/// </summary>
virtual misc::IntervalSet nextTokens(ATNState *s, RuleContext *ctx) const;
misc::IntervalSet nextTokens(ATNState *s, RuleContext *ctx) const;

/// <summary>
/// Compute the set of valid tokens that can occur starting in {@code s} and
/// staying in same rule. <seealso cref="Token#EPSILON"/> is in set if we reach end of
/// rule.
/// </summary>
virtual misc::IntervalSet const& nextTokens(ATNState *s) const;
misc::IntervalSet const& nextTokens(ATNState *s) const;

virtual void addState(ATNState *state);
void addState(ATNState *state);

virtual void removeState(ATNState *state);
void removeState(ATNState *state);

virtual int defineDecisionState(DecisionState *s);
int defineDecisionState(DecisionState *s);

virtual DecisionState *getDecisionState(size_t decision) const;
DecisionState *getDecisionState(size_t decision) const;

virtual size_t getNumberOfDecisions() const;
size_t getNumberOfDecisions() const;

/// <summary>
/// Computes the set of input symbols which could follow ATN state number
Expand All @@ -106,12 +115,17 @@ namespace atn {
/// specified state in the specified context. </returns>
/// <exception cref="IllegalArgumentException"> if the ATN does not contain a state with
/// number {@code stateNumber} </exception>
virtual misc::IntervalSet getExpectedTokens(size_t stateNumber, RuleContext *context) const;
misc::IntervalSet getExpectedTokens(size_t stateNumber, RuleContext *context) const;

std::string toString() const;

private:
friend class LexerATNSimulator;
friend class ParserATNSimulator;

mutable std::mutex _mutex;
mutable std::shared_mutex _stateMutex;
mutable std::shared_mutex _edgeMutex;
};

} // namespace atn
Expand Down
Loading