-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcassandra-driver.cpp
151 lines (130 loc) · 5.18 KB
/
cassandra-driver.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "cassandra-driver.h"
#include "cassandra-parser.h"
// Tolerance used by CassDriver::addObsTransition when deciding whether an
// observation probability counts as exactly 0 or exactly 1.
const double epsilon = 0.000001L;
// Creates a driver that fills in the POMDP pointed to by `p`.
// The pointer is only stored; nothing in this file deletes it
// (NOTE(review): presumably owned by the caller — confirm).
// Observations are assumed deterministic until addObsTransition sees a
// probability that is neither 0 nor 1.
CassDriver::CassDriver(BWC::POMDP* p) :
trace_scanning(false), trace_parsing(false), pomdp(p),
has_det_obs(true) {
  // weight_sign used to be left uninitialised, so addWeight read an
  // indeterminate value when the input never called setWeightSign.
  // Default to +1 (weights applied as given). Assigned in the body rather
  // than the init list to avoid depending on the header's member order.
  this->weight_sign = 1;
}
// Parses the file `f` with yy::CassParser, populating this->pomdp.
// Returns the parser's status code (Bison convention: 0 on success).
// If parsing succeeded and some observation probability was neither 0 nor 1,
// the POMDP is post-processed with makeObsDet() so its observations become
// deterministic.
int CassDriver::parse(const std::string &f) {
  this->file = f;
  beginScan();
  yy::CassParser parser(*this);
  parser.set_debug_level(this->trace_parsing);
  int res = parser.parse();
  endScan();
  // Only post-process a completely parsed model. After a parse error the
  // POMDP is partially built, so the isValidMdp() assertion and makeObsDet()
  // would operate on incomplete data instead of letting the caller see `res`.
  if (res == 0 && !this->has_det_obs) {
    assert(this->pomdp->isValidMdp());
    this->pomdp->makeObsDet();
  }
  return res;
}
// Reports message `m` on stderr, prefixed with the source location `l`.
// std::cerr is unit-buffered, so a plain newline is flushed immediately.
void CassDriver::error(const yy::location &l, const std::string &m) {
  std::cerr << l << ": " << m << '\n';
}
// Reports message `m` on stderr without location information.
// std::cerr is unit-buffered, so a plain newline is flushed immediately.
void CassDriver::error(const std::string &m) {
  std::cerr << "error: " << m << '\n';
}
// Adds a transition source --action--> target with probability `prob`.
// `source` and `target` must be named states. `action` is either a named
// action or ELEMREFTYPE_ALL, in which case the transition is added for every
// action id in [0, getActionCount()).
void CassDriver::addTransition(ElemRef source, ElemRef action, ElemRef target,
                               double prob) {
  assert(source.type == ELEMREFTYPE_NAME);
  assert(target.type == ELEMREFTYPE_NAME);
  assert(action.type != ELEMREFTYPE_ID);
  // The action-id list must outlive the loop below. Previously it was built
  // inside each if/else branch, so the iterators dangled once the branch
  // scope ended (and were uninitialised if neither branch ran).
  std::vector<int> action_ids;
  if (action.type == ELEMREFTYPE_NAME) {
    action_ids.push_back(this->pomdp->getActionId(action.name));
  } else if (action.type == ELEMREFTYPE_ALL) {
    for (int i = 0; i < this->pomdp->getActionCount(); i++)
      action_ids.push_back(i);
  }
  if (action_ids.empty())
    return;
  // Source and target ids do not depend on the action; look them up once.
  const int source_id = this->pomdp->getStateId(source.name);
  const int target_id = this->pomdp->getStateId(target.name);
  for (std::vector<int>::const_iterator a_it = action_ids.begin();
       a_it != action_ids.end(); ++a_it) {
    this->pomdp->addTransition(source_id, *a_it, target_id, prob);
  }
}
// Innermost step of addWeight: applies `weight` to (source, action, target)
// in `p`. A target of ELEMREFTYPE_ALL expands to every state index.
void addWeightHelper2(BWC::POMDP* p, int source,
                      int action, ElemRef target, double weight) {
  if (target.type != ELEMREFTYPE_ALL) {
    p->weightTransition(source, action, p->getStateId(target.name), weight);
    return;
  }
  for (int tgt = 0; tgt < p->getStateCount(); ++tgt) {
    p->weightTransition(source, action, tgt, weight);
  }
}
// Middle step of addWeight: expands `action` (a named action or
// ELEMREFTYPE_ALL for every action index) and forwards each resulting
// action id to addWeightHelper2.
void addWeightHelper1(BWC::POMDP* p, int source,
                      ElemRef action, ElemRef target, double weight) {
  if (action.type != ELEMREFTYPE_ALL) {
    addWeightHelper2(p, source, p->getActionId(action.name), target, weight);
    return;
  }
  for (int act = 0; act < p->getActionCount(); ++act) {
    addWeightHelper2(p, source, act, target, weight);
  }
}
// Adds `weight` (multiplied by weight_sign, see setWeightSign) to every
// transition matched by source/action/target. Each of the three may be a
// name or ELEMREFTYPE_ALL; the observation must be ELEMREFTYPE_ALL.
void CassDriver::addWeight(ElemRef source, ElemRef action, ElemRef target,
                           ElemRef obs, double weight) {
  assert(obs.type == ELEMREFTYPE_ALL);
  assert(source.type != ELEMREFTYPE_ID);
  assert(action.type != ELEMREFTYPE_ID);
  assert(target.type != ELEMREFTYPE_ID);
  const double signed_weight = weight * this->weight_sign;
  if (source.type != ELEMREFTYPE_ALL) {
    addWeightHelper1(this->pomdp, this->pomdp->getStateId(source.name),
                     action, target, signed_weight);
    return;
  }
  for (int src = 0; src < this->pomdp->getStateCount(); ++src) {
    addWeightHelper1(this->pomdp, src, action, target, signed_weight);
  }
}
// Records that named state `target` emits named observation `obs` with
// probability `prob` (for all actions). Also tracks whether every
// observation probability seen so far is, within epsilon, exactly 0 or 1;
// parse() uses that flag to decide whether to call makeObsDet().
void CassDriver::addObsTransition(ElemRef action, ElemRef target, ElemRef obs,
                                  double prob) {
  assert(action.type == ELEMREFTYPE_ALL);
  assert(target.type == ELEMREFTYPE_NAME);
  assert(obs.type == ELEMREFTYPE_NAME);
  this->pomdp->addObservationProb(this->pomdp->getStateId(target.name),
                                  this->pomdp->getObservationId(obs.name),
                                  prob);
  if (this->has_det_obs) {
    const bool near_one = std::abs(prob - 1.0L) < epsilon;
    const bool near_zero = std::abs(prob) < epsilon;
    this->has_det_obs = near_one || near_zero;
  }
}
void CassDriver::setDiscount(double discount) {
this->pomdp->setDiscFactor(discount);
}
void CassDriver::setWeightSign(int sign) {
this->weight_sign = sign;
}
void CassDriver::setStates(std::vector<std::string> states) {
this->pomdp->setStates(states);
}
void CassDriver::setActions(std::vector<std::string> actions) {
this->pomdp->setActions(actions);
}
void CassDriver::setObservations(std::vector<std::string> observations) {
this->pomdp->setObservations(observations);
}
// Sets a uniform initial distribution over the named states in `states`.
// Each listed state must be ELEMREFTYPE_NAME. Duplicate names are counted
// once, so the probabilities always sum to 1. Previously a repeated name
// overwrote its own map entry while the denominator still counted it twice,
// leaving a distribution that summed to less than 1.
// An empty list yields an empty distribution.
void CassDriver::setInitialDist(std::vector<ElemRef> states) {
  std::map<int, double> initial_dist;
  // First pass: collect the distinct state ids.
  for (std::vector<ElemRef>::const_iterator i = states.begin();
       i != states.end(); ++i) {
    assert(i->type == ELEMREFTYPE_NAME);
    initial_dist[this->pomdp->getStateId(i->name)] = 0.0;
  }
  // Second pass: spread the probability mass evenly over those ids.
  if (!initial_dist.empty()) {
    const double mass = 1.0 / initial_dist.size();
    for (std::map<int, double>::iterator it = initial_dist.begin();
         it != initial_dist.end(); ++it)
      it->second = mass;
  }
  this->pomdp->setInitialDist(initial_dist);
}
// Sets the initial distribution from an explicit per-state probability
// vector. probs[i] is the probability of state id i, and the vector must
// have one entry per state. Only strictly positive entries are stored.
void CassDriver::setInitialDist(std::vector<double> probs) {
  typedef std::vector<double>::size_type index_t;
  // Cast explicitly to avoid a signed/unsigned comparison with the state count.
  assert(probs.size() ==
         static_cast<index_t>(this->pomdp->getStateCount()));
  std::map<int, double> initial_dist;
  for (index_t i = 0; i < probs.size(); ++i) {
    if (probs[i] > 0.0)
      initial_dist[static_cast<int>(i)] = probs[i];
  }
  this->pomdp->setInitialDist(initial_dist);
}