-
Notifications
You must be signed in to change notification settings - Fork 4
/
qlearning.h
71 lines (52 loc) · 1.77 KB
/
qlearning.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#ifndef QLEARNING_TABLE_H
#define QLEARNING_TABLE_H
#include "tiles.h"
#include "ns3/double.h"
#include "ns3/random-variable-stream.h"
// NOTE(review): using-directives at header scope are injected into every
// translation unit that includes this file. The class declared below relies
// on the ns3 directive for the unqualified names Ptr and UniformRandomVariable,
// so it cannot simply be removed without qualifying those names first.
using namespace std;
using namespace ns3;
// NDIM is not referenced anywhere in this header; presumably it is the number
// of state dimensions used by the implementation file — TODO confirm.
#define NDIM 2
/**
 * @brief Declaration of a learning agent exposing action selection,
 *        Q-learning / SARSA(lambda) / Q(lambda) style update entry points and
 *        table save/load hooks.
 *
 * Only declarations are visible in this header. Every behavioural remark
 * below is inferred from identifier names and must be confirmed against the
 * implementation file.
 *
 * NOTE(review): state vectors are taken as unsized `double[]` parameters,
 * which are plain `double*` in a parameter list, so the expected element
 * count is not enforced by the type — presumably NDIM elements; confirm.
 */
class QlearningTable
{
public:
/**
 * @brief Constructs the agent from its hyper-parameters.
 *
 * @param actions        presumably the number of discrete actions (see m_actions).
 * @param learning_rate  presumably the step size (see m_learning_rate).
 * @param reward_decay   presumably the discount factor (see m_reward_decay).
 * @param e_greddy       (sic) presumably the epsilon-greedy parameter stored in
 *                       m_e_greedy — TODO confirm its exact meaning.
 * @param cs_id          a single-character identifier; there is no matching
 *                       member, so it is possibly used to build m_fname — confirm.
 * @param ntiling        presumably the number of tilings (see m_ntiling).
 * @param ntiles         presumably the number of tiles (see m_ntiles).
 * @param learning_decay presumably a decay applied to the learning rate.
 * @param explore_decay  presumably a decay applied to the exploration rate.
 * @param lambda         presumably the eligibility-trace decay (see m_lambda).
 */
QlearningTable(int actions,
double learning_rate,
double reward_decay,
double e_greddy,
char cs_id,
int ntiling,
int ntiles,
double learning_decay,
double explore_decay,
double lambda
);
// Action selection. Each returns an int, presumably an action index in
// [0, m_actions) — confirm the range in the implementation.
int choose_action(double stateVec[]);
int choose_best(double stateVec[]);
int choose_random();
// Update step taking the previous state/action, the reward and the new state.
void qlearning_update(double lastStateVec[], int last_action, double reward, double newStateVec[]);
// SARSA(lambda) entry points, split into a "before" call that returns a double
// and "after"/"terminal" calls that accept a TDerror. Presumably the value
// returned by the "before" call is what callers pass on as TDerror — confirm.
double update_sarsa_lambda_before(double lastStateVec[], int last_action, double reward);
void update_sarsa_lambda_after(double newStateVec[], int new_action, double TDerror);
void update_sarsa_lambda_terminal(double TDerror);
// Q(lambda) entry points with the same before/after split as above.
double update_q_lambda_before(double lastStateVec[], int last_action, double reward);
void update_q_lambda_after(double newStateVec[], int new_action, double TDerror);
// Returns a double for the given state and action id. tiles_array is an
// unsized int buffer; NOTE(review): whether it is an input or an output, and
// its required length, cannot be told from this header — confirm.
double calculate_action_value_q_estimate(double stateVec[], int act_id, int tiles_array[]);
// Takes no arguments and returns nothing; what it adjusts is not visible here.
void set_parameter();
// Persistence hooks. The int return is presumably a status code and the target
// is presumably the file named by m_fname — confirm both.
int save_table();
int load_table();
// Takes no arguments and returns nothing; presumably a diagnostic helper.
void check_qtable();
private:
// The members below mirror the constructor parameters of similar name.
int m_actions;
double m_learning_rate;
double m_reward_decay;
double m_e_greedy;
double m_lambda; // lambda for eligibility trace decay
int m_ntiling;
int m_ntiles;
// Two double vectors; presumably function-approximation weights and their
// eligibility traces. Sizes are decided in the implementation, not here.
std::vector<double> m_weights;
std::vector<double> m_traces;
double m_learning_decay;
double m_explore_decay;
// Presumably the file name used by save_table()/load_table() — confirm.
std::string m_fname;
// ns-3 uniform random stream handle. The unqualified names resolve only
// because of the `using namespace ns3;` directive earlier in this header.
Ptr<UniformRandomVariable> m_rand_probability;
};
#endif // QLEARNING_TABLE_H