-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstock_env.cpp
152 lines (130 loc) · 4.21 KB
/
stock_env.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include "stock_env.hpp"

#include <cstdlib>
#include <ios>
#include <stdexcept>
#include <string>

#include <spdlog/spdlog.h>
bool stock_env::init = false;
stock_env::stock_env(config_mgr2& xcs_config,std::shared_ptr<spdlog::logger> logger):
current_reward(0),account_(account("000001")),account_com_(account("000001",spdlog::stdout_color_mt("account_com_"))),
current_state_(0), logger_(logger),account_path(std::string("")),
account_stock_num_path(std::string("")),account_com_path(std::string("")),
account_com_stock_num_path(std::string(""))
{
if(!stock_env::init)
{
if(!xcs_config.exist(tag_name()))
{
logger_->error("class:{} method:{} msg:{}",class_name(), "constructor", "section <" + tag_name() + "> not found");
exit(1);
}
std::string filepath;
try{
filepath = (std::string)xcs_config.Value(tag_name(),"stock data file");
data_ = getData(filepath);
}
catch (const char *attribute){
std::string msg = "attribute \'" + std::string(attribute) + "\' not found in <" + tag_name() + ">";
logger_->error("class:{} method:{} msg:{}",class_name(), "constructor", msg);
exit(1);
}
}
stock_env::init = true;
}
// Starts a new episode.
// Credits both the traded account and the comparison account with the same
// starting amount and calls updateAccountPath(0). It then builds the input
// bit string for day 1.
void stock_env::begin_problem()
{
    constexpr auto starting_cash = 10000000;
    account_.addMoney(starting_cash);
    account_com_.addMoney(starting_cash);
    updateAccountPath(0);

    // set_input compares against the previous day, so day 0 cannot be the
    // first observed state; the episode begins on day 1.
    current_state_ = 1;
    set_input(current_state_);
}
inline void stock_env::set_input(int64_t pos) //the pos must bigger than 0
{
std::string input("");
if(getValue(data_,pos,"close") >= getValue(data_, pos-1,"close"))
{
input += "1";
}
else
{
input += "0";
}
if(getValue(data_,pos,"volume") >= getValue(data_, pos-1,"volume"))
{
input += "1";
}
else
{
input += "0";
}
if(getValue(data_,pos,"volume") >= getValue(data_, pos,"VMA20"))
{
input += "1";
}
else
{
input += "0";
}
if(getValue(data_,pos,"close") >= getValue(data_, pos,"PMA5"))
{
input += "1";
}
else
{
input += "0";
}
if(getValue(data_,pos,"close") >= getValue(data_, pos,"PMA10"))
{
input += "1";
}
else
{
input += "0";
}
inputs.set_string_value(input);
logger_->info("The {} day: {}",pos, inputs.string_value());
}
// Returns the value stored under `key` in row `pos` of `data`.
//
// Throws std::invalid_argument if `data` or the selected row is null, and
// std::out_of_range if `pos` is negative or not a valid row index. Previously
// such indices (e.g. pos-1 == -1 from set_input(0)) were undefined behaviour.
//
// A key that is absent from the row yields 0.0, the same value the former
// map::operator[] produced, but the shared data is no longer mutated by
// inserting the missing key.
double getValue(std::shared_ptr<std::vector<std::shared_ptr<std::map<std::string, double>>>>const &data, int64_t pos, std::string const & key)
{
    if (!data)
    {
        throw std::invalid_argument("getValue: data is null");
    }
    if (pos < 0 || static_cast<std::uint64_t>(pos) >= data->size())
    {
        throw std::out_of_range("getValue: row index " + std::to_string(pos) + " is out of range");
    }
    const auto& row = (*data)[static_cast<std::size_t>(pos)];
    if (!row)
    {
        throw std::invalid_argument("getValue: row " + std::to_string(pos) + " is null");
    }
    const auto it = row->find(key);
    return it == row->end() ? 0.0 : it->second;
}
void stock_env::perform(const binary_action& action)
{
std::string action_str = action.string_value();
logger_->info("perform action: {}", action_str);
bool buy = action_str[0] == '0'? false:true;
double percent(0);
for(int i=1;i<4;i++)
{
percent = percent*2 + (action_str[i]-'0');
}
double target_percnet = percent/7;
// set the reward
double diff = getValue(data_,current_state_+1,"close")-getValue(data_,current_state_,"close");
if(percent != 0 && buy)
{
current_reward = diff * percent * account_.getMoney()/getValue(data_, current_state_,"close")/100000;
}
else if(percent != 0 && !buy)
{
current_reward = -diff*percent*account_.getStockAmount()/1000;
}
if(!buy) target_percnet = -target_percnet;
// perform action on the env
logger_->info(" buy or sell {} percent {}", buy, target_percnet);
account_.order_percent(getValue(data_, current_state_,"close"), target_percnet);
if(account_com_.getStockAmount()==0)
{
account_com_.order_target_percent(getValue(data_, current_state_,"close"),0.99);
}
logger_->info("The infomation of the account:the total value {} money {} stock number {}",
account_.getValue(getValue(data_, current_state_,"close")), account_.getMoney(),account_.getStockAmount());
logger_->info("Compare with the hold stock number {} total value {}", account_com_.getStockAmount(), account_com_.getValue(getValue(data_, current_state_,"close")));
}
// Serializes the environment state: an empty line followed by the current day
// index on its own line. restore_state reads the index back.
void stock_env::save_state(std::ostream& output) const
{
    output << std::endl
           << current_state_ << std::endl;
}
void stock_env::restore_state(std::istream& input)
{
input >> current_state_;
set_input(current_state_);
}