Skip to content

Commit

Permalink
[C++] Add T::is for type hierarchy checks and remove some dynamic_cast
Browse files Browse the repository at this point in the history
  • Loading branch information
jcking committed Mar 30, 2022
1 parent 7e4b48b commit 91b6504
Show file tree
Hide file tree
Showing 46 changed files with 283 additions and 110 deletions.
33 changes: 15 additions & 18 deletions runtime/Cpp/runtime/src/ParserRuleContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,36 @@ void ParserRuleContext::removeLastChild() {
}
}

tree::TerminalNode* ParserRuleContext::getToken(size_t ttype, size_t i) {
tree::TerminalNode* ParserRuleContext::getToken(size_t ttype, size_t i) const {
if (i >= children.size()) {
return nullptr;
}

size_t j = 0; // what token with ttype have we found?
for (auto *o : children) {
if (o->getTreeType() == ParseTreeType::TERMINAL || o->getTreeType() == ParseTreeType::ERROR) {
tree::TerminalNode *tnode = downCast<tree::TerminalNode *>(o);
Token *symbol = tnode->getSymbol();
for (auto *child : children) {
if (TerminalNode::is(child)) {
tree::TerminalNode *typedChild = downCast<tree::TerminalNode*>(child);
Token *symbol = typedChild->getSymbol();
if (symbol->getType() == ttype) {
if (j++ == i) {
return tnode;
return typedChild;
}
}
}
}

return nullptr;
}

std::vector<tree::TerminalNode *> ParserRuleContext::getTokens(size_t ttype) {
std::vector<tree::TerminalNode *> tokens;
for (auto &o : children) {
if (o->getTreeType() == ParseTreeType::TERMINAL || o->getTreeType() == ParseTreeType::ERROR) {
tree::TerminalNode *tnode = downCast<tree::TerminalNode *>(o);
Token *symbol = tnode->getSymbol();
std::vector<tree::TerminalNode *> ParserRuleContext::getTokens(size_t ttype) const {
std::vector<tree::TerminalNode*> tokens;
for (auto *child : children) {
if (TerminalNode::is(child)) {
tree::TerminalNode *typedChild = downCast<tree::TerminalNode*>(child);
Token *symbol = typedChild->getSymbol();
if (symbol->getType() == ttype) {
tokens.push_back(tnode);
tokens.push_back(typedChild);
}
}
}

return tokens;
}

Expand All @@ -124,11 +121,11 @@ misc::Interval ParserRuleContext::getSourceInterval() {
return misc::Interval(start->getTokenIndex(), stop->getTokenIndex());
}

Token* ParserRuleContext::getStart() {
Token* ParserRuleContext::getStart() const {
return start;
}

Token* ParserRuleContext::getStop() {
Token* ParserRuleContext::getStop() const {
return stop;
}

Expand Down
40 changes: 20 additions & 20 deletions runtime/Cpp/runtime/src/ParserRuleContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ namespace antlr4 {

ParserRuleContext();
ParserRuleContext(ParserRuleContext *parent, size_t invokingStateNumber);
virtual ~ParserRuleContext() {}

/** COPY a ctx (I'm deliberately not using copy constructor) to avoid
* confusion with creating node with parent. Does not copy children
Expand All @@ -88,38 +87,39 @@ namespace antlr4 {
/// Used by enterOuterAlt to toss out a RuleContext previously added as
/// we entered a rule. If we have # label, we will need to remove
/// generic ruleContext object.
virtual void removeLastChild();
void removeLastChild();

virtual tree::TerminalNode* getToken(size_t ttype, std::size_t i);
tree::TerminalNode* getToken(size_t ttype, std::size_t i) const;

virtual std::vector<tree::TerminalNode *> getTokens(size_t ttype);
std::vector<tree::TerminalNode*> getTokens(size_t ttype) const;

template<typename T>
T* getRuleContext(size_t i) {
if (children.empty()) {
return nullptr;
}

T* getRuleContext(size_t i) const {
static_assert(std::is_base_of_v<RuleContext, T>, "T must be derived from RuleContext");
size_t j = 0; // what element have we found with ctxType?
for (auto &child : children) {
if (antlrcpp::is<T *>(child)) {
if (j++ == i) {
return dynamic_cast<T *>(child);
for (auto *child : children) {
if (RuleContext::is(child)) {
if (auto *typedChild = dynamic_cast<T*>(child); typedChild != nullptr) {
if (j++ == i) {
return typedChild;
}
}
}
}
return nullptr;
}

template<typename T>
std::vector<T *> getRuleContexts() {
std::vector<T *> contexts;
std::vector<T*> getRuleContexts() const {
static_assert(std::is_base_of_v<RuleContext, T>, "T must be derived from RuleContext");
std::vector<T*> contexts;
for (auto *child : children) {
if (antlrcpp::is<T *>(child)) {
contexts.push_back(dynamic_cast<T *>(child));
if (RuleContext::is(child)) {
if (auto *typedChild = dynamic_cast<T*>(child); typedChild != nullptr) {
contexts.push_back(typedChild);
}
}
}

return contexts;
}

Expand All @@ -130,14 +130,14 @@ namespace antlr4 {
* Note that the range from start to stop is inclusive, so for rules that do not consume anything
* (for example, zero length or error productions) this token may exceed stop.
*/
virtual Token *getStart();
Token* getStart() const;

/**
* Get the final token in this context.
* Note that the range from start to stop is inclusive, so for rules that do not consume anything
* (for example, zero length or error productions) this token may precede start.
*/
virtual Token *getStop();
Token* getStop() const;

/// <summary>
/// Used for rule context info debugging during parse-time, not so much for ATN debugging </summary>
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/RuleContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ namespace antlr4 {
*/
class ANTLR4CPP_PUBLIC RuleContext : public tree::ParseTree {
public:
static bool is(const tree::ParseTree &parseTree) { return parseTree.getTreeType() == tree::ParseTreeType::RULE; }

static bool is(const tree::ParseTree *parseTree) { return parseTree != nullptr && is(*parseTree); }

/// What state invoked the rule associated with this context?
/// The "return address" is the followState of invokingState
/// If parent is null, this should be -1 and this context object represents the start rule.
Expand Down
54 changes: 27 additions & 27 deletions runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ namespace {
*/
void markPrecedenceDecisions(const ATN &atn) {
for (ATNState *state : atn.states) {
if (!is<StarLoopEntryState*>(state)) {
if (!StarLoopEntryState::is(state)) {
continue;
}

Expand All @@ -92,8 +92,8 @@ namespace {
*/
if (atn.ruleToStartState[state->ruleIndex]->isLeftRecursiveRule) {
ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (is<LoopEndState *>(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
if (LoopEndState::is(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && RuleStopState::is(maybeLoopEndState->transitions[0]->target)) {
downCast<StarLoopEntryState*>(state)->isPrecedenceDecision = true;
}
}
Expand Down Expand Up @@ -291,7 +291,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
if (stype == ATNStateType::LOOP_END) { // special case
int loopBackStateNumber = data[p++];
loopBackStateNumbers.push_back({ downCast<LoopEndState*>(s), loopBackStateNumber });
} else if (is<BlockStartState*>(s)) {
} else if (BlockStartState::is(s)) {
int endStateNumber = data[p++];
endStateNumbers.push_back({ downCast<BlockStartState*>(s), endStateNumber });
}
Expand Down Expand Up @@ -340,7 +340,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da

atn->ruleToStopState.resize(nrules);
for (ATNState *state : atn->states) {
if (!is<RuleStopState*>(state)) {
if (!RuleStopState::is(state)) {
continue;
}

Expand Down Expand Up @@ -389,7 +389,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
for (ATNState *state : atn->states) {
for (size_t i = 0; i < state->transitions.size(); i++) {
const Transition *t = state->transitions[i].get();
if (!is<const RuleTransition*>(t)) {
if (!RuleTransition::is(t)) {
continue;
}

Expand All @@ -407,7 +407,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
}

for (ATNState *state : atn->states) {
if (is<BlockStartState*>(state)) {
if (BlockStartState::is(state)) {
BlockStartState *startState = downCast<BlockStartState*>(state);

// we need to know the end state to set its start state
Expand All @@ -423,19 +423,19 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
startState->endState->startState = downCast<BlockStartState*>(state);
}

if (is<PlusLoopbackState*>(state)) {
if (PlusLoopbackState::is(state)) {
PlusLoopbackState *loopbackState = downCast<PlusLoopbackState*>(state);
for (size_t i = 0; i < loopbackState->transitions.size(); i++) {
ATNState *target = loopbackState->transitions[i]->target;
if (is<PlusBlockStartState*>(target)) {
if (PlusBlockStartState::is(target)) {
(downCast<PlusBlockStartState*>(target))->loopBackState = loopbackState;
}
}
} else if (is<StarLoopbackState*>(state)) {
} else if (StarLoopbackState::is(state)) {
StarLoopbackState *loopbackState = downCast<StarLoopbackState*>(state);
for (size_t i = 0; i < loopbackState->transitions.size(); i++) {
ATNState *target = loopbackState->transitions[i]->target;
if (is<StarLoopEntryState *>(target)) {
if (StarLoopEntryState::is(target)) {
downCast<StarLoopEntryState*>(target)->loopBackState = loopbackState;
}
}
Expand Down Expand Up @@ -506,16 +506,16 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
continue;
}

if (!is<StarLoopEntryState*>(state)) {
if (!StarLoopEntryState::is(state)) {
continue;
}

ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (!is<LoopEndState*>(maybeLoopEndState)) {
if (!LoopEndState::is(maybeLoopEndState)) {
continue;
}

if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
if (maybeLoopEndState->epsilonOnlyTransitions && RuleStopState::(maybeLoopEndState->transitions[0]->target)) {
endState = state;
break;
}
Expand Down Expand Up @@ -578,52 +578,52 @@ void ATNDeserializer::verifyATN(const ATN &atn) const {

checkCondition(state->epsilonOnlyTransitions || state->transitions.size() <= 1);

if (is<PlusBlockStartState*>(state)) {
if (PlusBlockStartState::is(state)) {
checkCondition((downCast<PlusBlockStartState*>(state))->loopBackState != nullptr);
}

if (is<StarLoopEntryState*>(state)) {
if (StarLoopEntryState::is(state)) {
StarLoopEntryState *starLoopEntryState = downCast<StarLoopEntryState*>(state);
checkCondition(starLoopEntryState->loopBackState != nullptr);
checkCondition(starLoopEntryState->transitions.size() == 2);

if (is<StarBlockStartState*>(starLoopEntryState->transitions[0]->target)) {
if (StarBlockStartState::is(starLoopEntryState->transitions[0]->target)) {
checkCondition(downCast<LoopEndState*>(starLoopEntryState->transitions[1]->target) != nullptr);
checkCondition(!starLoopEntryState->nonGreedy);
} else if (is<LoopEndState*>(starLoopEntryState->transitions[0]->target)) {
checkCondition(is<StarBlockStartState*>(starLoopEntryState->transitions[1]->target));
} else if (LoopEndState::is(starLoopEntryState->transitions[0]->target)) {
checkCondition(StarBlockStartState::is(starLoopEntryState->transitions[1]->target));
checkCondition(starLoopEntryState->nonGreedy);
} else {
throw IllegalStateException();
}
}

if (is<StarLoopbackState*>(state)) {
if (StarLoopbackState::is(state)) {
checkCondition(state->transitions.size() == 1);
checkCondition(is<StarLoopEntryState*>(state->transitions[0]->target));
checkCondition(StarLoopEntryState::is(state->transitions[0]->target));
}

if (is<LoopEndState*>(state)) {
if (LoopEndState::is(state)) {
checkCondition((downCast<LoopEndState*>(state))->loopBackState != nullptr);
}

if (is<RuleStartState*>(state)) {
if (RuleStartState::is(state)) {
checkCondition((downCast<RuleStartState*>(state))->stopState != nullptr);
}

if (is<BlockStartState*>(state)) {
if (BlockStartState::is(state)) {
checkCondition((downCast<BlockStartState*>(state))->endState != nullptr);
}

if (is<BlockEndState*>(state)) {
if (BlockEndState::is(state)) {
checkCondition((downCast<BlockEndState*>(state))->startState != nullptr);
}

if (is<DecisionState*>(state)) {
if (DecisionState::is(state)) {
DecisionState *decisionState = downCast<DecisionState*>(state);
checkCondition(decisionState->transitions.size() <= 1 || decisionState->decision >= 0);
} else {
checkCondition(state->transitions.size() <= 1 || is<RuleStopState*>(state));
checkCondition(state->transitions.size() <= 1 || RuleStopState::is(state));
}
}
}
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/ActionTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC ActionTransition final : public Transition {
public:
static bool is(const Transition &transition) { return transition.getTransitionType() == TransitionType::ACTION; }

static bool is(const Transition *transition) { return transition != nullptr && is(*transition); }

const size_t ruleIndex;
const size_t actionIndex;
const bool isCtxDependent; // e.g., $i ref in action
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC ArrayPredictionContext final : public PredictionContext {
public:
static bool is(const PredictionContext &predictionContext) { return predictionContext.getContextType() == PredictionContextType::ARRAY; }

static bool is(const PredictionContext *predictionContext) { return predictionContext != nullptr && is(*predictionContext); }

/// Parent can be empty only if full ctx mode and we make an array
/// from EMPTY and non-empty. We merge EMPTY by using null parent and
/// returnState == EMPTY_RETURN_STATE.
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/AtomTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {
/// TODO: make all transitions sets? no, should remove set edges.
class ANTLR4CPP_PUBLIC AtomTransition final : public Transition {
public:
static bool is(const Transition &transition) { return transition.getTransitionType() == TransitionType::ATOM; }

static bool is(const Transition *transition) { return transition != nullptr && is(*transition); }

/// The token type or character value; or, signifies special label.
/// TODO: rename this to label
const size_t _label;
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BasicBlockStartState.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC BasicBlockStartState final : public BlockStartState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BLOCK_START; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BasicBlockStartState() : BlockStartState(ATNStateType::BLOCK_START) {}
};

Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BasicState.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC BasicState final : public ATNState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BASIC; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BasicState() : ATNState(ATNStateType::BASIC) {}
};

Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BlockEndState.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {
/// Terminal node of a simple {@code (a|b|c)} block.
class ANTLR4CPP_PUBLIC BlockEndState final : public ATNState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BLOCK_END; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BlockStartState *startState = nullptr;

BlockEndState() : ATNState(ATNStateType::BLOCK_END) {}
Expand Down
Loading

0 comments on commit 91b6504

Please sign in to comment.