-
Notifications
You must be signed in to change notification settings - Fork 2
/
ast_optimizer.cpp
131 lines (111 loc) · 4.52 KB
/
ast_optimizer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "ast_optimizer.h"

#include <memory>
namespace LangUMS
{
std::shared_ptr<IASTNode> ASTOptimizer::Process(std::shared_ptr<IASTNode> ast)
{
m_Root = std::move(ast);
auto childCount = m_Root->GetChildCount();
for (auto i = 0u; i < childCount; i++)
{
auto& child = m_Root->GetChild(i);
auto newChild = CalculateConstantExpressions(child);
newChild = ConcatenateStrings(newChild);
m_Root->SetChild(i, newChild);
}
return std::move(m_Root);
}
std::shared_ptr<IASTNode> ASTOptimizer::CalculateConstantExpressions(const std::shared_ptr<IASTNode>& node)
{
auto childCount = node->GetChildCount();
for (auto i = 0u; i < childCount; i++)
{
auto& child = node->GetChild(i);
node->SetChild(i, CalculateConstantExpressions(child));
}
if (node->GetType() == ASTNodeType::BinaryExpression)
{
auto expression = (ASTBinaryExpression*)node.get();
auto& lhs = expression->GetLHSValue();
auto& rhs = expression->GetRHSValue();
if (lhs->GetType() == ASTNodeType::NumberLiteral && rhs->GetType() == ASTNodeType::NumberLiteral)
{
auto lhsNumber = (ASTNumberLiteral*)lhs.get();
auto rhsNumber = (ASTNumberLiteral*)rhs.get();
auto result = CalculateConstantBinaryExpression(lhsNumber->GetValue(), rhsNumber->GetValue(), expression->GetOperator());
return std::shared_ptr<IASTNode>(new ASTNumberLiteral(result, node->GetCharIndex()));
}
}
else if (node->GetType() == ASTNodeType::UnaryExpression)
{
auto expression = (ASTUnaryExpression*)node.get();
auto& value = expression->GetValue();
if (value->GetType() == ASTNodeType::NumberLiteral &&
expression->GetOperator() == OperatorType::Not)
{
auto valueNumber = (ASTNumberLiteral*)value.get();
auto result = valueNumber->GetValue() > 0 ? 0 : 1;
return std::shared_ptr<IASTNode>(new ASTNumberLiteral(result, node->GetCharIndex()));
}
}
return node;
}
std::shared_ptr<IASTNode> ASTOptimizer::ConcatenateStrings(const std::shared_ptr<IASTNode>& node)
{
auto childCount = node->GetChildCount();
for (auto i = 0u; i < childCount; i++)
{
auto& child = node->GetChild(i);
node->SetChild(i, ConcatenateStrings(child));
}
if (node->GetType() != ASTNodeType::BinaryExpression)
{
return node;
}
auto expression = (ASTBinaryExpression*)node.get();
if (expression->GetOperator() != OperatorType::Add)
{
return node;
}
auto& lhs = expression->GetLHSValue();
auto& rhs = expression->GetRHSValue();
if (lhs->GetType() == ASTNodeType::StringLiteral && rhs->GetType() == ASTNodeType::StringLiteral)
{
auto lhsString = (ASTStringLiteral*)lhs.get();
auto rhsString = (ASTStringLiteral*)rhs.get();
auto result = rhsString->GetValue() + lhsString->GetValue();
return std::shared_ptr<IASTNode>(new ASTStringLiteral(result, node->GetCharIndex()));
}
return node;
}
int ASTOptimizer::CalculateConstantBinaryExpression(int left, int right, OperatorType op)
{
switch (op)
{
case OperatorType::Or:
return (left != 0 || right != 0) ? 1 : 0;
case OperatorType::And:
return (left != 0 && right != 0) ? 1 : 0;
case OperatorType::Equals:
return (left == right) ? 1 : 0;
case OperatorType::NotEquals:
return (left != right) ? 1 : 0;
case OperatorType::GreaterThan:
return (left > right) ? 1 : 0;
case OperatorType::GreaterThanOrEquals:
return (left >= right) ? 1 : 0;
case OperatorType::LessThan:
return (left < right) ? 1 : 0;
case OperatorType::LessThanOrEquals:
return (left <= right) ? 1 : 0;
case OperatorType::Add:
return left + right;
case OperatorType::Subtract:
return left - right;
case OperatorType::Divide:
return left / right;
case OperatorType::Multiply:
return left * right;
}
return 0;
}
}